diff --git a/sed/config/flash_example_config.yaml b/sed/config/flash_example_config.yaml index 6bb05521..d294b752 100644 --- a/sed/config/flash_example_config.yaml +++ b/sed/config/flash_example_config.yaml @@ -9,6 +9,8 @@ core: beamtime_id: 11013410 # the year of the beamtime year: 2023 + # the instrument used + instrument: hextof # hextof, wespe, etc # The paths to the raw and parquet data directories. If these are not # provided, the loader will try to find the data based on year beamtimeID etc @@ -52,18 +54,20 @@ dataframe: tof_ns_column: dldTime # dataframe column containing corrected time-of-flight data corrected_tof_column: "tm" + # the time stamp column + time_stamp_alias: timeStamp # time length of a base time-of-flight bin in seconds tof_binwidth: 2.0576131995767355E-11 # binning parameter for time-of-flight data. 2**tof_binning bins per base bin tof_binning: 3 # power of 2, 3 means 8 bins per step # dataframe column containing sector ID. obtained from dldTimeSteps column sector_id_column: dldSectorID - sector_delays: [0., 0., 0., 0., 0., 0., 0., 0.] # the delay stage column delay_column: delayStage # the corrected pump-probe time axis corrected_delay_column: pumpProbeTime + # the columns to be used for jitter correction jitter_cols: ["dldPosX", "dldPosY", "dldTimeSteps"] units: @@ -95,24 +99,28 @@ dataframe: # The timestamp timeStamp: format: per_train - group_name: "/uncategorised/FLASH.DIAG/TIMINGINFO/TIME1.BUNCH_FIRST_INDEX.1/" + index_key: "/uncategorised/FLASH.DIAG/TIMINGINFO/TIME1.BUNCH_FIRST_INDEX.1/index" + dataset_key: "/uncategorised/FLASH.DIAG/TIMINGINFO/TIME1.BUNCH_FIRST_INDEX.1/time" # pulse ID is a necessary channel for using the loader. pulseId: format: per_electron - group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" + index_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/index" + dataset_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/value" slice: 2 # detector x position dldPosX: format: per_electron - group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" + index_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/index" + dataset_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/value" slice: 1 # detector y position dldPosY: format: per_electron - group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" + index_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/index" + dataset_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/value" slice: 0 # Detector time-of-flight channel @@ -120,14 +128,16 @@ dataframe: # also the dldSectorID channel dldTimeSteps: format: per_electron - group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" + index_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/index" + dataset_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/value" slice: 3 # The auxillary channel has a special structure where the group further contains # a multidimensional structure so further aliases are defined below dldAux: format: per_pulse - group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" + index_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/index" + dataset_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/value" slice: 4 dldAuxChannels: sampleBias: 0 @@ -141,29 +151,35 @@ dataframe: # ADC containing the pulser sign (1: value approx. 35000, 0: 33000) pulserSignAdc: format: per_pulse - group_name: "/FL1/Experiment/PG/SIS8300 100MHz ADC/CH6/TD/" + index_key: "/FL1/Experiment/PG/SIS8300 100MHz ADC/CH6/TD/index" + dataset_key: "/FL1/Experiment/PG/SIS8300 100MHz ADC/CH6/TD/value" # the energy of the monochromatized beam. This is a quasi-static value. 
# there is a better channel which still needs implementation. monochromatorPhotonEnergy: format: per_train - group_name: "/FL1/Beamlines/PG/Monochromator/monochromator photon energy/" + index_key: "/FL1/Beamlines/PG/Monochromator/monochromator photon energy/index" + dataset_key: "/FL1/Beamlines/PG/Monochromator/monochromator photon energy/value" # The GMDs can not be read yet... gmdBda: format: per_train - group_name: "/FL1/Photon Diagnostic/GMD/Average energy/energy BDA/" + index_key: "/FL1/Photon Diagnostic/GMD/Average energy/energy BDA/index" + dataset_key: "/FL1/Photon Diagnostic/GMD/Average energy/energy BDA/value" + # Beam Arrival Monitor, vital for pump-probe experiments as it can compensate sase # timing fluctuations. # Here we use the DBC2 BAM as the "normal" one is broken. bam: format: per_pulse - group_name: "/uncategorised/FLASH.SDIAG/BAM.DAQ/FL0.DBC2.ARRIVAL_TIME.ABSOLUTE.SA1.COMP/" + index_key: "/uncategorised/FLASH.SDIAG/BAM.DAQ/FL0.DBC2.ARRIVAL_TIME.ABSOLUTE.SA1.COMP/index" + dataset_key: "/uncategorised/FLASH.SDIAG/BAM.DAQ/FL0.DBC2.ARRIVAL_TIME.ABSOLUTE.SA1.COMP/value" # The delay Stage position, encoding the pump-probe delay delayStage: format: per_train - group_name: "/zraw/FLASH.SYNC/LASER.LOCK.EXP/F1.PG.OSC/FMC0.MD22.1.ENCODER_POSITION.RD/dGroup/" + index_key: "/zraw/FLASH.SYNC/LASER.LOCK.EXP/F1.PG.OSC/FMC0.MD22.1.ENCODER_POSITION.RD/dGroup/index" + dataset_key: "/zraw/FLASH.SYNC/LASER.LOCK.EXP/F1.PG.OSC/FMC0.MD22.1.ENCODER_POSITION.RD/dGroup/value" # The prefixes of the stream names for different DAQ systems for parsing filenames # (Not to be changed by user) diff --git a/sed/core/metadata.py b/sed/core/metadata.py index d38fe313..d930d2a3 100644 --- a/sed/core/metadata.py +++ b/sed/core/metadata.py @@ -57,6 +57,8 @@ def _format_attributes(self, attributes: dict, indent: int = 0) -> str: INDENT_FACTOR = 20 html = "" for key, value in attributes.items(): + # Ensure the key is a string + key = str(key) # Format key formatted_key = key.replace("_", " ").title() formatted_key = f"{formatted_key}" diff --git a/sed/loader/flash/buffer_handler.py b/sed/loader/flash/buffer_handler.py new file mode 100644 index 00000000..7c66f8f2 --- /dev/null +++ b/sed/loader/flash/buffer_handler.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +import os +from itertools import compress +from pathlib import Path + +import dask.dataframe as dd +import pyarrow.parquet as pq +from joblib import delayed +from joblib import Parallel + +from sed.core.dfops import forward_fill_lazy +from sed.loader.flash.dataframe import DataFrameCreator +from sed.loader.flash.utils import get_channels +from sed.loader.flash.utils import initialize_paths +from sed.loader.utils import get_parquet_metadata +from sed.loader.utils import split_dld_time_from_sector_id + + +class BufferHandler: + """ + A class for handling the creation and manipulation of buffer files using DataFrameCreator. + """ + + def __init__( + self, + config: dict, + ) -> None: + """ + Initializes the BufferHandler. + + Args: + config (dict): The configuration dictionary. + """ + self._config = config["dataframe"] + self.n_cores = config["core"].get("num_cores", os.cpu_count() - 1) + + self.buffer_paths: list[Path] = [] + self.missing_h5_files: list[Path] = [] + self.save_paths: list[Path] = [] + + self.df_electron: dd.DataFrame = None + self.df_pulse: dd.DataFrame = None + self.metadata: dict = {} + + def _schema_check(self) -> None: + """ + Checks the schema of the Parquet files. 
+ + Raises: + ValueError: If the schema of the Parquet files does not match the configuration. + """ + existing_parquet_filenames = [file for file in self.buffer_paths if file.exists()] + parquet_schemas = [pq.read_schema(file) for file in existing_parquet_filenames] + config_schema_set = set( + get_channels(self._config["channels"], formats="all", index=True, extend_aux=True), + ) + + for filename, schema in zip(existing_parquet_filenames, parquet_schemas): + # for retro compatibility when sectorID was also saved in buffer + if self._config["sector_id_column"] in schema.names: + config_schema_set.add( + self._config["sector_id_column"], + ) + schema_set = set(schema.names) + if schema_set != config_schema_set: + missing_in_parquet = config_schema_set - schema_set + missing_in_config = schema_set - config_schema_set + + errors = [] + if missing_in_parquet: + errors.append(f"Missing in parquet: {missing_in_parquet}") + if missing_in_config: + errors.append(f"Missing in config: {missing_in_config}") + + raise ValueError( + f"The available channels do not match the schema of file {filename}. " + f"{' '.join(errors)}. " + "Please check the configuration file or set force_recreate to True.", + ) + + def _get_files_to_read( + self, + h5_paths: list[Path], + folder: Path, + prefix: str, + suffix: str, + force_recreate: bool, + ) -> None: + """ + Determines the list of files to read and the corresponding buffer files to create. + + Args: + h5_paths (List[Path]): List of paths to H5 files. + folder (Path): Path to the folder for buffer files. + prefix (str): Prefix for buffer file names. + suffix (str): Suffix for buffer file names. + force_recreate (bool): Flag to force recreation of buffer files. + """ + # Getting the paths of the buffer files, with subfolder as buffer and no extension + self.buffer_paths = initialize_paths( + filenames=[h5_path.stem for h5_path in h5_paths], + folder=folder, + subfolder="buffer", + prefix=prefix, + suffix=suffix, + extension="", + ) + # read only the files that do not exist or if force_recreate is True + files_to_read = [ + force_recreate or not parquet_path.exists() for parquet_path in self.buffer_paths + ] + + # Get the list of H5 files to read and the corresponding buffer files to create + self.missing_h5_files = list(compress(h5_paths, files_to_read)) + self.save_paths = list(compress(self.buffer_paths, files_to_read)) + + print(f"Reading files: {len(self.missing_h5_files)} new files of {len(h5_paths)} total.") + + def _save_buffer_file(self, h5_path: Path, parquet_path: Path) -> None: + """ + Creates a single buffer file. + + Args: + h5_path (Path): Path to the H5 file. + parquet_path (Path): Path to the buffer file. + """ + + # Create a DataFrameCreator instance and the h5 file + df = DataFrameCreator(config_dataframe=self._config, h5_path=h5_path).df + + # Reset the index of the DataFrame and save it as a parquet file + df.reset_index().to_parquet(parquet_path) + + def _save_buffer_files(self, debug: bool) -> None: + """ + Creates the buffer files. + + Args: + debug (bool): Flag to enable debug mode, which serializes the creation. 
+ """ + n_cores = min(len(self.missing_h5_files), self.n_cores) + paths = zip(self.missing_h5_files, self.save_paths) + if n_cores > 0: + if debug: + for h5_path, parquet_path in paths: + self._save_buffer_file(h5_path, parquet_path) + else: + Parallel(n_jobs=n_cores, verbose=10)( + delayed(self._save_buffer_file)(h5_path, parquet_path) + for h5_path, parquet_path in paths + ) + + def _fill_dataframes(self): + """ + Reads all parquet files into one dataframe using dask and fills NaN values. + """ + dataframe = dd.read_parquet(self.buffer_paths, calculate_divisions=True) + file_metadata = get_parquet_metadata( + self.buffer_paths, + time_stamp_col=self._config.get("time_stamp_alias", "timeStamp"), + ) + self.metadata["file_statistics"] = file_metadata + + fill_channels: list[str] = get_channels( + self._config["channels"], + ["per_pulse", "per_train"], + extend_aux=True, + ) + index: list[str] = get_channels(index=True) + overlap = min(file["num_rows"] for file in file_metadata.values()) + + dataframe = forward_fill_lazy( + df=dataframe, + columns=fill_channels, + before=overlap, + iterations=self._config.get("forward_fill_iterations", 2), + ) + self.metadata["forward_fill"] = { + "columns": fill_channels, + "overlap": overlap, + "iterations": self._config.get("forward_fill_iterations", 2), + } + + # Drop rows with nan values in electron channels + df_electron = dataframe.dropna( + subset=get_channels(self._config["channels"], ["per_electron"]), + ) + + # Set the dtypes of the channels here as there should be no null values + channel_dtypes = get_channels(self._config["channels"], "all") + config_channels = self._config["channels"] + dtypes = { + channel: config_channels[channel].get("dtype") + for channel in channel_dtypes + if config_channels[channel].get("dtype") is not None + } + + # Correct the 3-bit shift which encodes the detector ID in the 8s time + if self._config.get("split_sector_id_from_dld_time", False): + df_electron, meta = split_dld_time_from_sector_id( + df_electron, + config=self._config, + ) + self.metadata.update(meta) + + self.df_electron = df_electron.astype(dtypes) + self.df_pulse = dataframe[index + fill_channels] + + def run( + self, + h5_paths: list[Path], + folder: Path, + force_recreate: bool = False, + prefix: str = "", + suffix: str = "", + debug: bool = False, + ) -> None: + """ + Runs the buffer file creation process. + + Args: + h5_paths (List[Path]): List of paths to H5 files. + folder (Path): Path to the folder for buffer files. + force_recreate (bool): Flag to force recreation of buffer files. + prefix (str): Prefix for buffer file names. + suffix (str): Suffix for buffer file names. + debug (bool): Flag to enable debug mode.): + """ + + self._get_files_to_read(h5_paths, folder, prefix, suffix, force_recreate) + + if not force_recreate: + self._schema_check() + + self._save_buffer_files(debug) + + self._fill_dataframes() diff --git a/sed/loader/flash/dataframe.py b/sed/loader/flash/dataframe.py new file mode 100644 index 00000000..0067279e --- /dev/null +++ b/sed/loader/flash/dataframe.py @@ -0,0 +1,294 @@ +""" +This module creates pandas DataFrames from HDF5 files for different levels of data granularity +[per electron, per pulse, and per train]. It efficiently handles concatenation of data from +various channels within the HDF5 file, making use of the structured nature data to optimize +join operations. This approach significantly enhances performance compared to earlier. 
+""" +from __future__ import annotations + +from pathlib import Path + +import h5py +import numpy as np +import pandas as pd + +from sed.loader.flash.utils import get_channels + + +class DataFrameCreator: + """ + A class for creating pandas DataFrames from an HDF5 file. + + Attributes: + h5_file (h5py.File): The HDF5 file object. + multi_index (pd.MultiIndex): The multi-index structure for the DataFrame. + _config (dict): The configuration dictionary for the DataFrame. + """ + + def __init__(self, config_dataframe: dict, h5_path: Path) -> None: + """ + Initializes the DataFrameCreator class. + + Args: + config_dataframe (dict): The configuration dictionary with only the dataframe key. + h5_path (Path): Path to the h5 file. + """ + self.h5_file = h5py.File(h5_path, "r") + self.multi_index = get_channels(index=True) + self._config = config_dataframe + + def get_index_dataset_key(self, channel: str) -> tuple[str, str]: + """ + Checks if 'index_key' and 'dataset_key' exists and returns that. + + Args: + channel (str): The name of the channel. + + Returns: + tuple[str, str]: Outputs a tuple of 'index_key' and 'dataset_key'. + + Raises: + ValueError: If 'index_key' and 'dataset_key' are not provided. + """ + channel_config = self._config["channels"][channel] + + if "index_key" in channel_config and "dataset_key" in channel_config: + return channel_config["index_key"], channel_config["dataset_key"] + elif "group_name" in channel_config: + print("'group_name' is no longer supported.") + + raise ValueError( + "For channel:", + channel, + "Provide both 'index_key' and 'dataset_key'.", + ) + + def get_dataset_array( + self, + channel: str, + slice_: bool = False, + ) -> tuple[pd.Index, h5py.Dataset]: + """ + Returns a numpy array for a given channel name. + + Args: + channel (str): The name of the channel. + slice_ (bool): If True, applies slicing on the dataset. + + Returns: + tuple[pd.Index, h5py.Dataset]: A tuple containing the train ID + pd.Index and the numpy array for the channel's data. + """ + # Get the data from the necessary h5 file and channel + index_key, dataset_key = self.get_index_dataset_key(channel) + + key = pd.Index(self.h5_file[index_key], name="trainId") # macrobunch + dataset = self.h5_file[dataset_key] + + if slice_: + slice_index = self._config["channels"][channel].get("slice", None) + if slice_index is not None: + dataset = np.take(dataset, slice_index, axis=1) + # If np_array is size zero, fill with NaNs + if dataset.shape[0] == 0: + # Fill the np_array with NaN values of the same shape as train_id + dataset = np.full_like(key, np.nan, dtype=np.double) + + return key, dataset + + def pulse_index(self, offset: int) -> tuple[pd.MultiIndex, slice | np.ndarray]: + """ + Creates a multi-level index that combines train IDs and pulse IDs, and handles + sorting and electron counting within each pulse. + + Args: + offset (int): The offset value. + + Returns: + tuple[pd.MultiIndex, np.ndarray]: A tuple containing the computed pd.MultiIndex and + the indexer. + """ + # Get the pulse_dataset and the train_index + train_index, pulse_dataset = self.get_dataset_array("pulseId", slice_=True) + # pulse_dataset comes as a 2D array, resolved per train. Here it is flattened + # the daq has an offset so no pulses are missed. This offset is subtracted here + pulse_ravel = pulse_dataset.ravel() - offset + # Here train_index is repeated to match the size of pulses + train_index_repeated = np.repeat(train_index, pulse_dataset.shape[1]) + # A pulse resolved multi-index is finally created. 
+ # Since there can be NaN pulses, those are dropped + pulse_index = pd.MultiIndex.from_arrays((train_index_repeated, pulse_ravel)).dropna() + + # Sometimes the pulse_index are not monotonic, so we might need to sort them + # The indexer is also returned to sort the data in df_electron + indexer = slice(None) + if not pulse_index.is_monotonic_increasing: + pulse_index, indexer = pulse_index.sort_values(return_indexer=True) + + # In the data, to signify different electrons, pulse_index is repeated by + # the number of electrons in each pulse. Here the values are counted + electron_counts = pulse_index.value_counts(sort=False).values + # Now we resolve each pulse to its electrons + electron_index = np.concatenate([np.arange(count) for count in electron_counts]) + + # Final multi-index constructed here + index = pd.MultiIndex.from_arrays( + ( + pulse_index.get_level_values(0), + pulse_index.get_level_values(1).astype(int), + electron_index, + ), + names=self.multi_index, + ) + return index, indexer + + @property + def df_electron(self) -> pd.DataFrame: + """ + Returns a pandas DataFrame for channel names of type [per electron]. + + Returns: + pd.DataFrame: The pandas DataFrame for the 'per_electron' channel's data. + """ + offset = self._config.get("ubid_offset", 5) # 5 is the default value + # Here we get the multi-index and the indexer to sort the data + index, indexer = self.pulse_index(offset) + + # Get the relevant channels and their slice index + channels = get_channels(self._config["channels"], "per_electron") + slice_index = [self._config["channels"][channel].get("slice", None) for channel in channels] + + # First checking if dataset keys are the same for all channels + # because DLD at FLASH stores all channels in the same h5 dataset + dataset_keys = [self.get_index_dataset_key(channel)[1] for channel in channels] + # Gives a true if all keys are the same + all_keys_same = all(key == dataset_keys[0] for key in dataset_keys) + + # If all dataset keys are the same, we only need to load the dataset once and slice + # the appropriate columns. This is much faster than loading the same dataset multiple times + if all_keys_same: + _, dataset = self.get_dataset_array(channels[0]) + data_dict = { + channel: dataset[:, slice_, :].ravel() + for channel, slice_ in zip(channels, slice_index) + } + dataframe = pd.DataFrame(data_dict) + # In case channels do differ, we create a pd.Series for each channel and concatenate them + else: + series = { + channel: pd.Series(self.get_dataset_array(channel, slice_=True)[1].ravel()) + for channel in channels + } + dataframe = pd.concat(series, axis=1) + + # after offset, the negative pulse values are dropped as they are not valid + drop_vals = np.arange(-offset, 0) + + # Few things happen here: + # Drop all NaN values like while creating the multiindex + # if necessary, the data is sorted with [indexer] + # pd.MultiIndex is set + # Finally, the offset values are dropped + return ( + dataframe.dropna() + .iloc[indexer] + .set_index(index) + .drop(index=drop_vals, level="pulseId", errors="ignore") + ) + + @property + def df_pulse(self) -> pd.DataFrame: + """ + Returns a pandas DataFrame for given channel names of type [per pulse]. + + Returns: + pd.DataFrame: The pandas DataFrame for the 'per_pulse' channel's data. 
+ """ + series = [] + # Get the relevant channel names + channels = get_channels(self._config["channels"], "per_pulse") + # For each channel, a pd.Series is created and appended to the list + for channel in channels: + # train_index and (sliced) data is returned + key, dataset = self.get_dataset_array(channel, slice_=True) + # Electron resolved MultiIndex is created. Since this is pulse data, + # the electron index is always 0 + index = pd.MultiIndex.from_product( + (key, np.arange(0, dataset.shape[1]), [0]), + names=self.multi_index, + ) + # The dataset is opened and converted to numpy array by [()] + # and flattened to resolve per pulse + # The pd.Series is created with the MultiIndex and appended to the list + series.append(pd.Series(dataset[()].ravel(), index=index, name=channel)) + + # All the channels are concatenated to a single DataFrame + return pd.concat( + series, + axis=1, + ) + + @property + def df_train(self) -> pd.DataFrame: + """ + Returns a pandas DataFrame for given channel names of type [per train]. + + Returns: + pd.DataFrame: The pandas DataFrame for the 'per_train' channel's data. + """ + series = [] + # Get the relevant channel names + channels = get_channels(self._config["channels"], "per_train") + # For each channel, a pd.Series is created and appended to the list + for channel in channels: + # train_index and (sliced) data is returned + key, dataset = self.get_dataset_array(channel, slice_=True) + # Electron and pulse resolved MultiIndex is created. Since this is train data, + # the electron and pulse index is always 0 + index = pd.MultiIndex.from_product( + (key, [0], [0]), + names=self.multi_index, + ) + # Auxillary dataset (which is stored in the same dataset as other DLD channels) + # contains multiple channels inside. Even though they are resolved per train, + # they come in pulse format, so the extra values are sliced and individual channels are + # created and appended to the list + if channel == "dldAux": + aux_channels = self._config["channels"]["dldAux"]["dldAuxChannels"].items() + for name, slice_aux in aux_channels: + series.append(pd.Series(dataset[: key.size, slice_aux], index, name=name)) + else: + series.append(pd.Series(dataset, index, name=channel)) + # All the channels are concatenated to a single DataFrame + return pd.concat(series, axis=1) + + def validate_channel_keys(self) -> None: + """ + Validates if the index and dataset keys for all channels in config exist in the h5 file. + + Raises: + KeyError: If the index or dataset keys do not exist in the file. + """ + for channel in self._config["channels"]: + index_key, dataset_key = self.get_index_dataset_key(channel) + if index_key not in self.h5_file: + raise KeyError(f"pd.Index key '{index_key}' doesn't exist in the file.") + if dataset_key not in self.h5_file: + raise KeyError(f"Dataset key '{dataset_key}' doesn't exist in the file.") + + @property + def df(self) -> pd.DataFrame: + """ + Joins the 'per_electron', 'per_pulse', and 'per_train' using join operation, + returning a single dataframe. + + Returns: + pd.DataFrame: The combined pandas DataFrame. 
+ """ + + self.validate_channel_keys() + return ( + self.df_electron.join(self.df_pulse, on=self.multi_index, how="outer") + .join(self.df_train, on=self.multi_index, how="outer") + .sort_index() + ) diff --git a/sed/loader/flash/instruments.py b/sed/loader/flash/instruments.py new file mode 100644 index 00000000..8ef0146e --- /dev/null +++ b/sed/loader/flash/instruments.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from dask import dataframe as dd + + +def wespe_convert(df: dd.DataFrame, df_timed: dd.DataFrame) -> tuple[dd.DataFrame, dd.DataFrame]: + df + df_timed + raise NotImplementedError("This function is not implemented yet.") diff --git a/sed/loader/flash/loader.py b/sed/loader/flash/loader.py index 6402f0cc..01d2aa62 100644 --- a/sed/loader/flash/loader.py +++ b/sed/loader/flash/loader.py @@ -2,34 +2,25 @@ This module implements the flash data loader. This loader currently supports hextof, wespe and instruments with similar structure. The raw hdf5 data is combined and saved into buffer files and loaded as a dask dataframe. -The dataframe is a amalgamation of all h5 files for a combination of runs, where the NaNs are -automatically forward filled across different files. +The dataframe is an amalgamation of all h5 files for a combination of runs, where the NaNs are +automatically forward-filled across different files. This can then be saved as a parquet for out-of-sed processing and reread back to access other sed functionality. """ from __future__ import annotations +import re import time from collections.abc import Sequence -from functools import reduce from pathlib import Path import dask.dataframe as dd -import h5py -import numpy as np -import pyarrow.parquet as pq -from joblib import delayed -from joblib import Parallel from natsort import natsorted -from pandas import DataFrame -from pandas import MultiIndex -from pandas import Series -from sed.core import dfops from sed.loader.base.loader import BaseLoader +from sed.loader.flash.buffer_handler import BufferHandler +from sed.loader.flash.instruments import wespe_convert from sed.loader.flash.metadata import MetadataRetriever -from sed.loader.utils import parse_h5_keys -from sed.loader.utils import split_dld_time_from_sector_id class FlashLoader(BaseLoader): @@ -44,19 +35,24 @@ class FlashLoader(BaseLoader): supported_file_types = ["h5"] def __init__(self, config: dict) -> None: - super().__init__(config=config) - self.multi_index = ["trainId", "pulseId", "electronId"] - self.index_per_electron: MultiIndex = None - self.index_per_pulse: MultiIndex = None - self.failed_files_error: list[str] = [] + """ + Initializes the FlashLoader. - def initialize_paths(self) -> tuple[list[Path], Path]: + Args: + config (dict): Configuration dictionary. """ - Initializes the paths based on the configuration. + super().__init__(config=config) + self.instrument: str = self._config["core"].get("instrument", "hextof") # default is hextof + self.raw_dir: str = None + self.parquet_dir: str = None - Returns: - tuple[list[Path], Path]: A tuple containing a list of raw data directories - paths and the parquet data directory path. + def _initialize_dirs(self) -> None: + """ + Initializes the directories on Maxwell based on configuration. If paths is provided in + the configuration, the raw data directory and parquet data directory are taken from there. + Otherwise, the beamtime_id and year are used to locate the data directories. + The first path that has either online- or express- prefix, or the daq name is taken as the + raw data directory. 
Raises: ValueError: If required values are missing from the configuration. @@ -75,10 +71,10 @@ def initialize_paths(self) -> tuple[list[Path], Path]: try: beamtime_id = self._config["core"]["beamtime_id"] year = self._config["core"]["year"] - daq = self._config["dataframe"]["daq"] + except KeyError as exc: raise ValueError( - "The beamtime_id, year and daq are required.", + "The beamtime_id and year are required.", ) from exc beamtime_dir = Path( @@ -93,11 +89,9 @@ def initialize_paths(self) -> tuple[list[Path], Path]: for path in raw_path.glob("**/*"): if path.is_dir(): dir_name = path.name - if dir_name.startswith("express-") or dir_name.startswith( - "online-", - ): - data_raw_dir.append(path.joinpath(daq)) - elif dir_name == daq.upper(): + if dir_name.startswith(("online-", "express-")): + data_raw_dir.append(path.joinpath(self._config["dataframe"]["daq"])) + elif dir_name == self._config["dataframe"]["daq"].upper(): data_raw_dir.append(path) if not data_raw_dir: @@ -108,25 +102,39 @@ def initialize_paths(self) -> tuple[list[Path], Path]: data_parquet_dir.mkdir(parents=True, exist_ok=True) - return data_raw_dir, data_parquet_dir + self.raw_dir = str(data_raw_dir[0].resolve()) + self.parquet_dir = str(data_parquet_dir) + + @property + def available_runs(self) -> list[int]: + # Get all files in raw_dir with "run" in their names + files = list(Path(self.raw_dir).glob("*run*")) + + # Extract run IDs from filenames + run_ids = set() + for file in files: + match = re.search(r"run(\d+)", file.name) + if match: + run_ids.add(int(match.group(1))) - def get_files_from_run_id( + # Return run IDs in sorted order + return sorted(list(run_ids)) + + def get_files_from_run_id( # type: ignore[override] self, - run_id: str, + run_id: str | int, folders: str | Sequence[str] = None, extension: str = "h5", - **kwds, ) -> list[str]: - """Returns a list of filenames for a given run located in the specified directory + """ + Returns a list of filenames for a given run located in the specified directory for the specified data acquisition (daq). Args: - run_id (str): The run identifier to locate. + run_id (str | int): The run identifier to locate. folders (str | Sequence[str], optional): The directory(ies) where the raw data is located. Defaults to config["core"]["base_folder"]. extension (str, optional): The file extension. Defaults to "h5". - kwds: Keyword arguments: - - daq (str): The data acquisition identifier. Returns: list[str]: A list of path strings representing the collected file names. @@ -143,7 +151,7 @@ def get_files_from_run_id( if isinstance(folders, str): folders = [folders] - daq = kwds.pop("daq", self._config.get("dataframe", {}).get("daq")) + daq = self._config["dataframe"].get("daq") # Generate the file patterns to search for in the directory file_pattern = f"{stream_name_prefixes[daq]}_run{run_id}_*." + extension @@ -167,711 +175,99 @@ def get_files_from_run_id( # Return the list of found files return [str(file.resolve()) for file in files] - @property - def available_channels(self) -> list: - """Returns the channel names that are available for use, - excluding pulseId, defined by the json file""" - available_channels = list(self._config["dataframe"]["channels"].keys()) - available_channels.remove("pulseId") - return available_channels - - def get_channels(self, formats: str | list[str] = "", index: bool = False) -> list[str]: - """ - Returns a list of channels associated with the specified format(s). 
- - Args: - formats (str | list[str]): The desired format(s) - ('per_pulse', 'per_electron', 'per_train', 'all'). - index (bool): If True, includes channels from the multi_index. - - Returns: - list[str]: A list of channels with the specified format(s). - """ - # If 'formats' is a single string, convert it to a list for uniform processing. - if isinstance(formats, str): - formats = [formats] - - # If 'formats' is a string "all", gather all possible formats. - if formats == ["all"]: - channels = self.get_channels(["per_pulse", "per_train", "per_electron"], index) - return channels - - channels = [] - for format_ in formats: - # Gather channels based on the specified format(s). - channels.extend( - key - for key in self.available_channels - if self._config["dataframe"]["channels"][key]["format"] == format_ - and key != "dldAux" - ) - # Include 'dldAuxChannels' if the format is 'per_pulse'. - if format_ == "per_pulse": - channels.extend( - self._config["dataframe"]["channels"]["dldAux"]["dldAuxChannels"].keys(), - ) - - # Include channels from multi_index if 'index' is True. - if index: - channels.extend(self.multi_index) - - return channels - - def reset_multi_index(self) -> None: - """Resets the index per pulse and electron""" - self.index_per_electron = None - self.index_per_pulse = None - - def create_multi_index_per_electron(self, h5_file: h5py.File) -> None: - """ - Creates an index per electron using pulseId for usage with the electron - resolved pandas DataFrame. - - Args: - h5_file (h5py.File): The HDF5 file object. - - Notes: - - This method relies on the 'pulseId' channel to determine - the macrobunch IDs. - - It creates a MultiIndex with trainId, pulseId, and electronId - as the index levels. - """ - - # Macrobunch IDs obtained from the pulseId channel - [train_id, np_array] = self.create_numpy_array_per_channel( - h5_file, - "pulseId", - ) - - # Create a series with the macrobunches as index and - # microbunches as values - macrobunches = ( - Series( - (np_array[i] for i in train_id.index), - name="pulseId", - index=train_id, - ) - - self._config["dataframe"]["ubid_offset"] - ) - - # Explode dataframe to get all microbunch vales per macrobunch, - # remove NaN values and convert to type int - microbunches = macrobunches.explode().dropna().astype(int) - - # Create temporary index values - index_temp = MultiIndex.from_arrays( - (microbunches.index, microbunches.values), - names=["trainId", "pulseId"], - ) - - # Calculate the electron counts per pulseId unique preserves the order of appearance - electron_counts = index_temp.value_counts()[index_temp.unique()].values - - # Series object for indexing with electrons - electrons = ( - Series( - [np.arange(electron_counts[i]) for i in range(electron_counts.size)], - ) - .explode() - .astype(int) - ) - - # Create a pandas MultiIndex using the exploded datasets - self.index_per_electron = MultiIndex.from_arrays( - (microbunches.index, microbunches.values, electrons), - names=self.multi_index, - ) - - def create_multi_index_per_pulse( - self, - train_id: Series, - np_array: np.ndarray, - ) -> None: - """ - Creates an index per pulse using a pulse resolved channel's macrobunch ID, for usage with - the pulse resolved pandas DataFrame. - - Args: - train_id (Series): The train ID Series. - np_array (np.ndarray): The numpy array containing the pulse resolved data. - - Notes: - - This method creates a MultiIndex with trainId and pulseId as the index levels. 
- """ - - # Create a pandas MultiIndex, useful for comparing electron and - # pulse resolved dataframes - self.index_per_pulse = MultiIndex.from_product( - (train_id, np.arange(0, np_array.shape[1])), - names=["trainId", "pulseId"], - ) - - def create_numpy_array_per_channel( - self, - h5_file: h5py.File, - channel: str, - ) -> tuple[Series, np.ndarray]: - """ - Returns a numpy array for a given channel name for a given file. - - Args: - h5_file (h5py.File): The h5py file object. - channel (str): The name of the channel. - - Returns: - tuple[Series, np.ndarray]: A tuple containing the train ID Series and the numpy array - for the channel's data. - - """ - # Get the data from the necessary h5 file and channel - group = h5_file[self._config["dataframe"]["channels"][channel]["group_name"]] - - channel_dict = self._config["dataframe"]["channels"][channel] # channel parameters - - train_id = Series(group["index"], name="trainId") # macrobunch - - # unpacks the timeStamp or value - if channel == "timeStamp": - np_array = group["time"][()] - else: - np_array = group["value"][()] - - # Use predefined axis and slice from the json file - # to choose correct dimension for necessary channel - if "slice" in channel_dict: - np_array = np.take( - np_array, - channel_dict["slice"], - axis=1, - ) - return train_id, np_array - - def create_dataframe_per_electron( - self, - np_array: np.ndarray, - train_id: Series, - channel: str, - ) -> DataFrame: - """ - Returns a pandas DataFrame for a given channel name of type [per electron]. - - Args: - np_array (np.ndarray): The numpy array containing the channel data. - train_id (Series): The train ID Series. - channel (str): The name of the channel. - - Returns: - DataFrame: The pandas DataFrame for the channel's data. - - Notes: - The microbunch resolved data is exploded and converted to a DataFrame. The MultiIndex - is set, and the NaN values are dropped, alongside the pulseId = 0 (meaningless). - - """ - return ( - Series((np_array[i] for i in train_id.index), name=channel) - .explode() - .dropna() - .to_frame() - .set_index(self.index_per_electron) - .drop( - index=np.arange(-self._config["dataframe"]["ubid_offset"], 0), - level=1, - errors="ignore", - ) - ) - - def create_dataframe_per_pulse( - self, - np_array: np.ndarray, - train_id: Series, - channel: str, - channel_dict: dict, - ) -> DataFrame: - """ - Returns a pandas DataFrame for a given channel name of type [per pulse]. - - Args: - np_array (np.ndarray): The numpy array containing the channel data. - train_id (Series): The train ID Series. - channel (str): The name of the channel. - channel_dict (dict): The dictionary containing channel parameters. - - Returns: - DataFrame: The pandas DataFrame for the channel's data. - - Notes: - - For auxillary channels, the macrobunch resolved data is repeated 499 times to be - compared to electron resolved data for each auxillary channel. The data is then - converted to a multicolumn DataFrame. - - For all other pulse resolved channels, the macrobunch resolved data is exploded - to a DataFrame and the MultiIndex is set. 
- - """ - - # Special case for auxillary channels - if channel == "dldAux": - # Checks the channel dictionary for correct slices and creates a multicolumn DataFrame - data_frames = ( - Series( - (np_array[i, value] for i in train_id.index), - name=key, - index=train_id, - ).to_frame() - for key, value in channel_dict["dldAuxChannels"].items() - ) - - # Multiindex set and combined dataframe returned - data = reduce(DataFrame.combine_first, data_frames) - - # For all other pulse resolved channels - else: - # Macrobunch resolved data is exploded to a DataFrame and the MultiIndex is set - - # Creates the index_per_pulse for the given channel - self.create_multi_index_per_pulse(train_id, np_array) - data = ( - Series((np_array[i] for i in train_id.index), name=channel) - .explode() - .to_frame() - .set_index(self.index_per_pulse) - ) - - return data - - def create_dataframe_per_train( - self, - np_array: np.ndarray, - train_id: Series, - channel: str, - ) -> DataFrame: - """ - Returns a pandas DataFrame for a given channel name of type [per train]. - - Args: - np_array (np.ndarray): The numpy array containing the channel data. - train_id (Series): The train ID Series. - channel (str): The name of the channel. - - Returns: - DataFrame: The pandas DataFrame for the channel's data. - """ - return ( - Series((np_array[i] for i in train_id.index), name=channel) - .to_frame() - .set_index(train_id) - ) - - def create_dataframe_per_channel( - self, - h5_file: h5py.File, - channel: str, - ) -> Series | DataFrame: - """ - Returns a pandas DataFrame for a given channel name from a given file. - - This method takes an h5py.File object `h5_file` and a channel name `channel`, and returns - a pandas DataFrame containing the data for that channel from the file. The format of the - DataFrame depends on the channel's format specified in the configuration. - - Args: - h5_file (h5py.File): The h5py.File object representing the HDF5 file. - channel (str): The name of the channel. - - Returns: - Series | DataFrame: A pandas Series or DataFrame representing the channel's data. - - Raises: - ValueError: If the channel has an undefined format. 
- - """ - [train_id, np_array] = self.create_numpy_array_per_channel( - h5_file, - channel, - ) # numpy Array created - channel_dict = self._config["dataframe"]["channels"][channel] # channel parameters - - # If np_array is size zero, fill with NaNs - if np_array.size == 0: - # Fill the np_array with NaN values of the same shape as train_id - np_array = np.full_like(train_id, np.nan, dtype=np.double) - # Create a Series using np_array, with train_id as the index - data = Series( - (np_array[i] for i in train_id.index), - name=channel, - index=train_id, - ) - - # Electron resolved data is treated here - if channel_dict["format"] == "per_electron": - # If index_per_electron is None, create it for the given file - if self.index_per_electron is None: - self.create_multi_index_per_electron(h5_file) - - # Create a DataFrame for electron-resolved data - data = self.create_dataframe_per_electron( - np_array, - train_id, - channel, - ) - - # Pulse resolved data is treated here - elif channel_dict["format"] == "per_pulse": - # Create a DataFrame for pulse-resolved data - data = self.create_dataframe_per_pulse( - np_array, - train_id, - channel, - channel_dict, - ) - - # Train resolved data is treated here - elif channel_dict["format"] == "per_train": - # Create a DataFrame for train-resolved data - data = self.create_dataframe_per_train(np_array, train_id, channel) - - else: - raise ValueError( - channel - + "has an undefined format. Available formats are \ - per_pulse, per_electron and per_train", - ) - - return data - - def concatenate_channels( - self, - h5_file: h5py.File, - ) -> DataFrame: - """ - Concatenates the channels from the provided h5py.File into a pandas DataFrame. - - This method takes an h5py.File object `h5_file` and concatenates the channels present in - the file into a single pandas DataFrame. The concatenation is performed based on the - available channels specified in the configuration. - - Args: - h5_file (h5py.File): The h5py.File object representing the HDF5 file. + def parse_metadata(self, scicat_token: str = None) -> dict: + """Uses the MetadataRetriever class to fetch metadata from scicat for each run. Returns: - DataFrame: A concatenated pandas DataFrame containing the channels. - - Raises: - ValueError: If the group_name for any channel does not exist in the file. 
- + dict: Metadata dictionary + scicat_token (str, optional):: The scicat token to use for fetching metadata """ - all_keys = parse_h5_keys(h5_file) # Parses all channels present - - # Check for if the provided group_name actually exists in the file - for channel in self._config["dataframe"]["channels"]: - if channel == "timeStamp": - group_name = self._config["dataframe"]["channels"][channel]["group_name"] + "time" - else: - group_name = self._config["dataframe"]["channels"][channel]["group_name"] + "value" - - if group_name not in all_keys: - raise ValueError( - f"The group_name for channel {channel} does not exist.", - ) - - # Create a generator expression to generate data frames for each channel - data_frames = ( - self.create_dataframe_per_channel(h5_file, each) for each in self.available_channels + metadata_retriever = MetadataRetriever(self._config["metadata"], scicat_token) + metadata = metadata_retriever.get_metadata( + beamtime_id=self._config["core"]["beamtime_id"], + runs=self.runs, + metadata=self.metadata, ) - # Use the reduce function to join the data frames into a single DataFrame - return reduce( - lambda left, right: left.join(right, how="outer"), - data_frames, - ) + return metadata - def create_dataframe_per_file( + def get_count_rate( self, - file_path: Path, - ) -> DataFrame: - """ - Create pandas DataFrames for the given file. - - This method loads an HDF5 file specified by `file_path` and constructs a pandas DataFrame - from the datasets within the file. The order of datasets in the DataFrames is the opposite - of the order specified by channel names. - - Args: - file_path (Path): Path to the input HDF5 file. - - Returns: - DataFrame: pandas DataFrame + fids: Sequence[int] = None, # noqa: ARG002 + **kwds, # noqa: ARG002 + ): + return None, None + def get_elapsed_time(self, fids: Sequence[int] = None, **kwds) -> float | list[float]: # type: ignore[override] """ - # Loads h5 file and creates a dataframe - with h5py.File(file_path, "r") as h5_file: - self.reset_multi_index() # Reset MultiIndexes for next file - df = self.concatenate_channels(h5_file) - df = df.dropna(subset=self._config["dataframe"].get("tof_column", "dldTimeSteps")) - # correct the 3 bit shift which encodes the detector ID in the 8s time - if self._config["dataframe"].get("split_sector_id_from_dld_time", False): - df = split_dld_time_from_sector_id(df, config=self._config) - return df - - def create_buffer_file(self, h5_path: Path, parquet_path: Path) -> bool | Exception: - """ - Converts an HDF5 file to Parquet format to create a buffer file. - - This method uses `create_dataframe_per_file` method to create dataframes from individual - files within an HDF5 file. The resulting dataframe is then saved to a Parquet file. + Calculates the elapsed time. Args: - h5_path (Path): Path to the input HDF5 file. - parquet_path (Path): Path to the output Parquet file. + fids (Sequence[int]): A sequence of file IDs. Defaults to all files. + **kwds: + - runs: A sequence of run IDs. Takes precedence over fids. + - aggregate: Whether to return the sum of the elapsed times across + the specified files or the elapsed time for each file. Defaults to True. Returns: - bool | Exception: Collected exceptions, if any. + float | list[float]: The elapsed time(s) in seconds. Raises: - ValueError: If an error occurs during the conversion process. - + KeyError: If a file ID in fids or a run ID in 'runs' does not exist in the metadata. 
""" try: - ( - self.create_dataframe_per_file(h5_path) - .reset_index(level=self.multi_index) - .to_parquet(parquet_path, index=False) - ) - except Exception as exc: # pylint: disable=broad-except - self.failed_files_error.append(f"{parquet_path}: {type(exc)} {exc}") - return exc - return None - - def buffer_file_handler( - self, - data_parquet_dir: Path, - detector: str, - force_recreate: bool, - ) -> tuple[list[Path], list, list]: - """ - Handles the conversion of buffer files (h5 to parquet) and returns the filenames. - - Args: - data_parquet_dir (Path): Directory where the parquet files will be stored. - detector (str): Detector name. - force_recreate (bool): Forces recreation of buffer files - - Returns: - tuple[list[Path], list, list]: Three lists, one for - parquet file paths, one for metadata and one for schema. - - Raises: - FileNotFoundError: If the conversion fails for any files or no data is available. - """ - - # Create the directory for buffer parquet files - buffer_file_dir = data_parquet_dir.joinpath("buffer") - buffer_file_dir.mkdir(parents=True, exist_ok=True) - - # Create two separate lists for h5 and parquet file paths - h5_filenames = [Path(file) for file in self.files] - parquet_filenames = [ - buffer_file_dir.joinpath(Path(file).stem + detector) for file in self.files - ] - existing_parquet_filenames = [file for file in parquet_filenames if file.exists()] - - # Raise a value error if no data is available after the conversion - if len(h5_filenames) == 0: - raise ValueError("No data available. Probably failed reading all h5 files") - - if not force_recreate: - # Check if the available channels match the schema of the existing parquet files - parquet_schemas = [pq.read_schema(file) for file in existing_parquet_filenames] - config_schema = set(self.get_channels(formats="all", index=True)) - if self._config["dataframe"].get("split_sector_id_from_dld_time", False): - config_schema.add(self._config["dataframe"].get("sector_id_column", False)) - - for i, schema in enumerate(parquet_schemas): - schema_set = set(schema.names) - if schema_set != config_schema: - missing_in_parquet = config_schema - schema_set - missing_in_config = schema_set - config_schema - - missing_in_parquet_str = ( - f"Missing in parquet: {missing_in_parquet}" if missing_in_parquet else "" - ) - missing_in_config_str = ( - f"Missing in config: {missing_in_config}" if missing_in_config else "" - ) - - raise ValueError( - "The available channels do not match the schema of file", - f"{existing_parquet_filenames[i]}", - f"{missing_in_parquet_str}", - f"{missing_in_config_str}", - "Please check the configuration file or set force_recreate to True.", - ) - - # Choose files to read - files_to_read = [ - (h5_path, parquet_path) - for h5_path, parquet_path in zip(h5_filenames, parquet_filenames) - if force_recreate or not parquet_path.exists() - ] - - print(f"Reading files: {len(files_to_read)} new files of {len(h5_filenames)} total.") - - # Initialize the indices for create_buffer_file conversion - self.reset_multi_index() - - # Convert the remaining h5 files to parquet in parallel if there are any - if len(files_to_read) > 0: - error = Parallel(n_jobs=len(files_to_read), verbose=10)( - delayed(self.create_buffer_file)(h5_path, parquet_path) - for h5_path, parquet_path in files_to_read - ) - if any(error): - raise RuntimeError(f"Conversion failed for some files. 
{error}") - - # Raise an error if the conversion failed for any files - # TODO: merge this and the previous error trackings - if self.failed_files_error: - raise FileNotFoundError( - "Conversion failed for the following files:\n" + "\n".join(self.failed_files_error), - ) + file_statistics = self.metadata["file_statistics"] + except Exception as exc: + raise KeyError( + "File statistics missing. Use 'read_dataframe' first.", + ) from exc - print("All files converted successfully!") - - # read all parquet metadata and schema - metadata = [pq.read_metadata(file) for file in parquet_filenames] - schema = [pq.read_schema(file) for file in parquet_filenames] - - return parquet_filenames, metadata, schema - - def parquet_handler( - self, - data_parquet_dir: Path, - detector: str = "", - parquet_path: Path = None, - converted: bool = False, - load_parquet: bool = False, - save_parquet: bool = False, - force_recreate: bool = False, - ) -> tuple[dd.DataFrame, dd.DataFrame]: - """ - Handles loading and saving of parquet files based on the provided parameters. - - Args: - data_parquet_dir (Path): Directory where the parquet files are located. - detector (str, optional): Adds a identifier for parquets to distinguish multidetector - systems. - parquet_path (str, optional): Path to the combined parquet file. - converted (bool, optional): True if data is augmented by adding additional columns - externally and saved into converted folder. - load_parquet (bool, optional): Loads the entire parquet into the dd dataframe. - save_parquet (bool, optional): Saves the entire dataframe into a parquet. - force_recreate (bool, optional): Forces recreation of buffer file. - Returns: - tuple[dd.DataFrame, dd.DataFrame]: A tuple containing two dataframes: - - dataframe_electron: Dataframe containing the loaded/augmented electron data. - - dataframe_pulse: Dataframe containing the loaded/augmented timed data. - - Raises: - FileNotFoundError: If the requested parquet file is not found. - - """ - - # Construct the parquet path if not provided - if parquet_path is None: - parquet_name = "_".join(str(run) for run in self.runs) - parquet_dir = data_parquet_dir.joinpath("converted") if converted else data_parquet_dir - - parquet_path = parquet_dir.joinpath( - "run_" + parquet_name + detector, - ).with_suffix(".parquet") - - # Check if load_parquet is flagged and then load the file if it exists - if load_parquet: + def get_elapsed_time_from_fid(fid): try: - dataframe_electron = dd.read_parquet(parquet_path) - dataframe_pulse = dataframe_electron - except Exception as exc: - raise FileNotFoundError( - "The final parquet for this run(s) does not exist yet. " - "If it is in another location, please provide the path as parquet_path.", + fid = str(fid) # Ensure the key is a string + time_stamps = file_statistics[fid]["time_stamps"] + elapsed_time = max(time_stamps) - min(time_stamps) + except KeyError as exc: + raise KeyError( + f"Timestamp metadata missing in file {fid}." 
+ "Add timestamp column and alias to config before loading.", ) from exc - else: - # Obtain the parquet filenames, metadata and schema from the method - # which handles buffer file creation/reading - filenames, metadata, _ = self.buffer_file_handler( - data_parquet_dir, - detector, - force_recreate, - ) - - # Read all parquet files into one dataframe using dask - dataframe = dd.read_parquet(filenames, calculate_divisions=True) - - # Channels to fill NaN values - channels: list[str] = self.get_channels(["per_pulse", "per_train"]) + return elapsed_time - overlap = min(file.num_rows for file in metadata) + def get_elapsed_time_from_run(run_id): + if self.raw_dir is None: + self._initialize_dirs() + files = self.get_files_from_run_id(run_id=run_id, folders=self.raw_dir) + fids = [self.files.index(file) for file in files] + return sum(get_elapsed_time_from_fid(fid) for fid in fids) - print("Filling nan values...") - dataframe = dfops.forward_fill_lazy( - df=dataframe, - columns=channels, - before=overlap, - iterations=self._config["dataframe"].get("forward_fill_iterations", 2), - ) - # Remove the NaNs from per_electron channels - dataframe_electron = dataframe.dropna( - subset=self.get_channels(["per_electron"]), - ) - dataframe_pulse = dataframe[ - self.multi_index + self.get_channels(["per_pulse", "per_train"]) - ] - dataframe_pulse = dataframe_pulse[ - (dataframe_pulse["electronId"] == 0) | (np.isnan(dataframe_pulse["electronId"])) - ] - - # Save the dataframe as parquet if requested - if save_parquet: - dataframe_electron.compute().reset_index(drop=True).to_parquet(parquet_path) - print("Combined parquet file saved.") - - return dataframe_electron, dataframe_pulse - - def parse_metadata(self, scicat_token: str = None) -> dict: - """Uses the MetadataRetriever class to fetch metadata from scicat for each run. - - Returns: - dict: Metadata dictionary - scicat_token (str, optional):: The scicat token to use for fetching metadata - """ - metadata_retriever = MetadataRetriever(self._config["metadata"], scicat_token) - metadata = metadata_retriever.get_metadata( - beamtime_id=self._config["core"]["beamtime_id"], - runs=self.runs, - metadata=self.metadata, - ) - - return metadata + elapsed_times = [] + runs = kwds.get("runs") + if runs is not None: + elapsed_times = [get_elapsed_time_from_run(run) for run in runs] + else: + if fids is None: + fids = range(len(self.files)) + elapsed_times = [get_elapsed_time_from_fid(fid) for fid in fids] - def get_count_rate( - self, - fids: Sequence[int] = None, # noqa: ARG002 - **kwds, # noqa: ARG002 - ): - return None, None + if kwds.get("aggregate", True): + elapsed_times = sum(elapsed_times) - def get_elapsed_time(self, fids=None, **kwds): # noqa: ARG002 - return None + return elapsed_times def read_dataframe( self, files: str | Sequence[str] = None, folders: str | Sequence[str] = None, - runs: str | Sequence[str] = None, + runs: str | int | Sequence[str | int] = None, ftype: str = "h5", - metadata: dict = None, + metadata: dict = {}, collect_metadata: bool = False, + detector: str = "", + force_recreate: bool = False, + parquet_dir: str | Path = None, + debug: bool = False, **kwds, ) -> tuple[dd.DataFrame, dd.DataFrame, dict]: """ @@ -882,9 +278,9 @@ def read_dataframe( folders (str | Sequence[str], optional): Path to folder(s) where files are stored Path has priority such that if it's specified, the specified files will be ignored. Defaults to None. - runs (str | Sequence[str], optional): Run identifier(s). 
Corresponding files will - be located in the location provided by ``folders``. Takes precedence over - ``files`` and ``folders``. Defaults to None. + runs (str | int | Sequence[str | int], optional): Run identifier(s). + Corresponding files will be located in the location provided by ``folders``. + Takes precedence over ``files`` and ``folders``. Defaults to None. ftype (str, optional): The file extension type. Defaults to "h5". metadata (dict, optional): Additional metadata. Defaults to None. collect_metadata (bool, optional): Whether to collect metadata. Defaults to False. @@ -899,24 +295,19 @@ def read_dataframe( """ t0 = time.time() - data_raw_dir, data_parquet_dir = self.initialize_paths() - + self._initialize_dirs() # Prepare a list of names for the runs to read and parquets to write if runs is not None: files = [] - if isinstance(runs, (str, int)): - runs = [runs] - for run in runs: + runs_ = [str(runs)] if isinstance(runs, (str, int)) else list(map(str, runs)) + for run in runs_: run_files = self.get_files_from_run_id( run_id=run, - folders=[str(folder.resolve()) for folder in data_raw_dir], - extension=ftype, - daq=self._config["dataframe"]["daq"], + folders=self.raw_dir, ) files.extend(run_files) - self.runs = list(runs) + self.runs = runs_ super().read_dataframe(files=files, ftype=ftype) - else: # This call takes care of files and folders. As we have converted runs into files # already, they are just stored in the class by this call. @@ -927,12 +318,36 @@ def read_dataframe( metadata=metadata, ) - df, df_timed = self.parquet_handler(data_parquet_dir, **kwds) + bh = BufferHandler( + config=self._config, + ) + + # if parquet_dir is None, use self.parquet_dir + parquet_dir = parquet_dir or self.parquet_dir + parquet_dir = Path(parquet_dir) + + # Obtain the parquet filenames, metadata, and schema from the method + # which handles buffer file creation/reading + h5_paths = [Path(file) for file in self.files] + bh.run( + h5_paths=h5_paths, + folder=parquet_dir, + force_recreate=force_recreate, + suffix=detector, + debug=debug, + ) + df = bh.df_electron + df_timed = bh.df_pulse + + if self.instrument == "wespe": + df, df_timed = wespe_convert(df, df_timed) + + self.metadata.update(self.parse_metadata(**kwds) if collect_metadata else {}) + self.metadata.update(bh.metadata) - metadata = self.parse_metadata(**kwds) if collect_metadata else {} print(f"loading complete in {time.time() - t0: .2f} s") - return df, df_timed, metadata + return df, df_timed, self.metadata LOADER = FlashLoader diff --git a/sed/loader/flash/utils.py b/sed/loader/flash/utils.py new file mode 100644 index 00000000..76af4150 --- /dev/null +++ b/sed/loader/flash/utils.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +from pathlib import Path + +# TODO: move to config +MULTI_INDEX = ["trainId", "pulseId", "electronId"] +PULSE_ALIAS = MULTI_INDEX[1] +DLD_AUX_ALIAS = "dldAux" +DLDAUX_CHANNELS = "dldAuxChannels" +FORMATS = ["per_electron", "per_pulse", "per_train"] + + +def get_channels( + channel_dict: dict = None, + formats: str | list[str] = None, + index: bool = False, + extend_aux: bool = False, +) -> list[str]: + """ + Returns a list of channels associated with the specified format(s). + + Args: + formats (str | list[str]): The desired format(s) + ('per_pulse', 'per_electron', 'per_train', 'all'). + index (bool): If True, includes channels from the multiindex. + extend_aux (bool): If True, includes channels from the 'dldAuxChannels' dictionary, + else includes 'dldAux'. 
+ + Returns: + List[str]: A list of channels with the specified format(s). + """ + # If 'formats' is a single string, convert it to a list for uniform processing. + if isinstance(formats, str): + formats = [formats] + + # If 'formats' is a string "all", gather all possible formats. + if formats == ["all"]: + channels = get_channels( + channel_dict, + FORMATS, + index, + extend_aux, + ) + return channels + + channels = [] + + # Include channels from multi_index if 'index' is True. + if index: + channels.extend(MULTI_INDEX) + + if formats: + # If 'formats' is a list, check if all elements are valid. + for format_ in formats: + if format_ not in FORMATS + ["all"]: + raise ValueError( + "Invalid format. Please choose from 'per_electron', 'per_pulse',\ + 'per_train', 'all'.", + ) + + # Get the available channels excluding 'pulseId'. + available_channels = list(channel_dict.keys()) + # raises error if not available, but necessary for pulse_index + available_channels.remove(PULSE_ALIAS) + + for format_ in formats: + # Gather channels based on the specified format(s). + channels.extend( + key + for key in available_channels + if channel_dict[key]["format"] == format_ and key != DLD_AUX_ALIAS + ) + # Include 'dldAuxChannels' if the format is 'per_pulse' and extend_aux is True. + # Otherwise, include 'dldAux'. + if format_ == FORMATS[2] and DLD_AUX_ALIAS in available_channels: + if extend_aux: + channels.extend( + channel_dict[DLD_AUX_ALIAS][DLDAUX_CHANNELS].keys(), + ) + else: + channels.extend([DLD_AUX_ALIAS]) + + return channels + + +def initialize_paths( + filenames: str | list[str] = None, + folder: Path = None, + subfolder: str = "", + prefix: str = "", + suffix: str = "", + extension: str = "parquet", + paths: list[Path] = None, +) -> list[Path]: + """ + Initialize the paths for files to be saved/loaded. + + If custom paths are provided, they will be used. Otherwise, paths will be generated based on + the specified parameters during initialization. + + Args: + filenames (str | list[str]): The name(s) of the file(s). + folder (Path): The folder where the files are saved. + subfolder (str): The subfolder where the files are saved. + prefix (str): The prefix for the file name. + suffix (str): The suffix for the file name. + extension (str): The extension for the file. + paths (list[Path]): Custom paths for the files. + + Returns: + list[Path]: The paths for the files. 
+ """ + # if filenames is string, convert it to a list + if isinstance(filenames, str): + filenames = [filenames] + + # Check if the folder and Parquet paths are provided + if not folder and not paths: + raise ValueError("Please provide folder or paths.") + if folder and not filenames: + raise ValueError("With folder, please provide filenames.") + + # Otherwise create the full path for the Parquet file + directory = folder.joinpath(subfolder) + directory.mkdir(parents=True, exist_ok=True) + + if extension: + extension = f".{extension}" # if extension is provided, it is prepended with a dot + if prefix: + prefix = f"{prefix}_" + if suffix: + suffix = f"_{suffix}" + paths = [directory.joinpath(Path(f"{prefix}{name}{suffix}{extension}")) for name in filenames] + + return paths diff --git a/sed/loader/sxp/loader.py b/sed/loader/sxp/loader.py index 96b1f7be..0ec40b89 100644 --- a/sed/loader/sxp/loader.py +++ b/sed/loader/sxp/loader.py @@ -51,15 +51,13 @@ def __init__(self, config: dict) -> None: self.index_per_pulse: MultiIndex = None self.failed_files_error: list[str] = [] self.array_indices: list[list[slice]] = None + self.raw_dir: str = None + self.parquet_dir: str = None - def initialize_paths(self) -> tuple[list[Path], Path]: + def _initialize_dirs(self): """ Initializes the paths based on the configuration. - Returns: - tuple[List[Path], Path]: A tuple containing a list of raw data directories - paths and the parquet data directory path. - Raises: ValueError: If required values are missing from the configuration. FileNotFoundError: If the raw data directories are not found. @@ -101,7 +99,8 @@ def initialize_paths(self) -> tuple[list[Path], Path]: data_parquet_dir.mkdir(parents=True, exist_ok=True) - return data_raw_dir, data_parquet_dir + self.raw_dir = data_raw_dir + self.parquet_dir = data_parquet_dir def get_files_from_run_id( self, @@ -654,7 +653,7 @@ def create_dataframe_per_file( df = df.dropna(subset=self._config["dataframe"].get("tof_column", "dldTimeSteps")) # correct the 3 bit shift which encodes the detector ID in the 8s time if self._config["dataframe"].get("split_sector_id_from_dld_time", False): - df = split_dld_time_from_sector_id(df, config=self._config) + df, _ = split_dld_time_from_sector_id(df, config=self._config) return df def create_buffer_file(self, h5_path: Path, parquet_path: Path) -> bool | Exception: @@ -943,7 +942,7 @@ def read_dataframe( """ t0 = time.time() - data_raw_dir, data_parquet_dir = self.initialize_paths() + self._initialize_dirs() # Prepare a list of names for the runs to read and parquets to write if runs is not None: @@ -953,7 +952,7 @@ def read_dataframe( for run in runs: run_files = self.get_files_from_run_id( run_id=run, - folders=[str(folder.resolve()) for folder in data_raw_dir], + folders=[str(Path(folder).resolve()) for folder in self.raw_dir], extension=ftype, daq=self._config["dataframe"]["daq"], ) @@ -971,7 +970,7 @@ def read_dataframe( metadata=metadata, ) - df, df_timed = self.parquet_handler(data_parquet_dir, **kwds) + df, df_timed = self.parquet_handler(Path(self.parquet_dir), **kwds) if collect_metadata: metadata = self.gather_metadata( diff --git a/sed/loader/utils.py b/sed/loader/utils.py index ba3778df..f47419c0 100644 --- a/sed/loader/utils.py +++ b/sed/loader/utils.py @@ -4,11 +4,13 @@ from collections.abc import Sequence from glob import glob +from pathlib import Path from typing import cast import dask.dataframe import numpy as np import pandas as pd +import pyarrow.parquet as pq from h5py import File from h5py import Group 
from natsort import natsorted @@ -149,7 +151,7 @@ def split_dld_time_from_sector_id( sector_id_column: str = None, sector_id_reserved_bits: int = None, config: dict = None, -) -> pd.DataFrame | dask.dataframe.DataFrame: +) -> tuple[pd.DataFrame | dask.dataframe.DataFrame, dict]: """Converts the 8s time in steps to time in steps and sectorID. The 8s detector encodes the dldSectorID in the 3 least significant bits of the @@ -162,7 +164,7 @@ def split_dld_time_from_sector_id( sector_id_column (str, optional): Name of the column containing the sectorID. Defaults to config["dataframe"]["sector_id_column"]. sector_id_reserved_bits (int, optional): Number of bits reserved for the - config (dict, optional): Configuration dictionary. Defaults to None. + config (dict, optional): Dataframe configuration dictionary. Defaults to None. Returns: pd.DataFrame | dask.dataframe.DataFrame: Dataframe with the new columns. @@ -170,28 +172,95 @@ def split_dld_time_from_sector_id( if tof_column is None: if config is None: raise ValueError("Either tof_column or config must be given.") - tof_column = config["dataframe"]["tof_column"] + tof_column = config["tof_column"] if sector_id_column is None: if config is None: raise ValueError("Either sector_id_column or config must be given.") - sector_id_column = config["dataframe"]["sector_id_column"] + sector_id_column = config["sector_id_column"] if sector_id_reserved_bits is None: if config is None: raise ValueError("Either sector_id_reserved_bits or config must be given.") - sector_id_reserved_bits = config["dataframe"].get("sector_id_reserved_bits", None) + sector_id_reserved_bits = config.get("sector_id_reserved_bits", None) if sector_id_reserved_bits is None: raise ValueError('No value for "sector_id_reserved_bits" found in config.') if sector_id_column in df.columns: - raise ValueError( - f"Column {sector_id_column} already in dataframe. This function is not idempotent.", + metadata = {"applied": False, "reason": f"Column {sector_id_column} already in dataframe"} + else: + # Split the time-of-flight column into sector ID and time-of-flight steps + df = split_channel_bitwise( + df=df, + input_column=tof_column, + output_columns=[sector_id_column, tof_column], + bit_mask=sector_id_reserved_bits, + overwrite=True, + types=[np.int8, np.int32], ) - df = split_channel_bitwise( - df=df, - input_column=tof_column, - output_columns=[sector_id_column, tof_column], - bit_mask=sector_id_reserved_bits, - overwrite=True, - types=[np.int8, np.int32], - ) - return df + metadata = { + "applied": True, + "tof_column": tof_column, + "sector_id_column": sector_id_column, + "sector_id_reserved_bits": sector_id_reserved_bits, + } + + return df, {"split_dld_time_from_sector_id": metadata} + + +def get_timestamp_stats(meta: pq.FileMetaData, time_stamp_col: str) -> tuple[int, int]: + """ + Extracts the minimum and maximum timestamps from the metadata of a Parquet file. + + Args: + meta (pq.FileMetaData): The metadata of the Parquet file. + time_stamp_col (str): The name of the column containing the timestamps. + + Returns: + Tuple[int, int]: The minimum and maximum timestamps. 
+ """ + idx = meta.schema.names.index(time_stamp_col) + timestamps = [] + for i in range(meta.num_row_groups): + stats = meta.row_group(i).column(idx).statistics + timestamps.append(stats.min) + timestamps.append(stats.max) + + return min(timestamps), max(timestamps) + + +def get_parquet_metadata(file_paths: list[Path], time_stamp_col: str) -> dict[str, dict]: + """ + Extracts and organizes metadata from a list of Parquet files. + + For each file, the function reads the metadata, adds the filename, and attempts to + extract the minimum and maximum timestamps. "row_groups" entry is removed from FileMetaData. + + Args: + file_paths (list[Path]): A list of paths to the Parquet files. + time_stamp_col (str): The name of the column containing the timestamps. + + Returns: + dict[str, dict]: A dictionary file index as key and the values as metadata of each file. + """ + organized_metadata = {} + for i, file_path in enumerate(file_paths): + # Read the metadata for the file + file_meta: pq.FileMetaData = pq.read_metadata(file_path) + # Convert the metadata to a dictionary + metadata_dict = file_meta.to_dict() + # Add the filename to the metadata dictionary + metadata_dict["filename"] = str(file_path.name) + + # Get the timestamp min and max + try: + timestamps = get_timestamp_stats(file_meta, time_stamp_col) + metadata_dict["time_stamps"] = timestamps + except ValueError: + pass + + # Remove "row_groups" as they contain a lot of info that is not needed + metadata_dict.pop("row_groups", None) + + # Add the metadata dictionary to the organized_metadata dictionary + organized_metadata[str(i)] = metadata_dict + + return organized_metadata diff --git a/tests/data/loader/flash/FLASH1_USER3_stream_2_run43878_file1_20230130T153807.1.h5 b/tests/data/loader/flash/FLASH1_USER3_stream_2_run43878_file1_20230130T153807.1.h5 index 02a04d9f..1eafbf33 100644 Binary files a/tests/data/loader/flash/FLASH1_USER3_stream_2_run43878_file1_20230130T153807.1.h5 and b/tests/data/loader/flash/FLASH1_USER3_stream_2_run43878_file1_20230130T153807.1.h5 differ diff --git a/tests/data/loader/flash/FLASH1_USER3_stream_2_run43879_file1_20230130T153807.1.h5 b/tests/data/loader/flash/FLASH1_USER3_stream_2_run43879_file1_20230130T153807.1.h5 new file mode 100644 index 00000000..1524e5ad Binary files /dev/null and b/tests/data/loader/flash/FLASH1_USER3_stream_2_run43879_file1_20230130T153807.1.h5 differ diff --git a/tests/data/loader/flash/config.yaml b/tests/data/loader/flash/config.yaml index b7049dfb..652e875c 100644 --- a/tests/data/loader/flash/config.yaml +++ b/tests/data/loader/flash/config.yaml @@ -81,29 +81,37 @@ dataframe: # pulse ID is a necessary channel for using the loader. 
pulseId: format: per_electron - group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" + index_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/index" + dataset_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/value" slice: 2 dldPosX: format: per_electron - group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" + index_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/index" + dataset_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/value" slice: 1 + dtype: uint16 dldPosY: format: per_electron - group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" + index_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/index" + dataset_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/value" slice: 0 + dtype: uint16 dldTimeSteps: format: per_electron - group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" + index_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/index" + dataset_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/value" slice: 3 + dtype: uint32 # The auxillary channel has a special structure where the group further contains # a multidim structure so further aliases are defined below dldAux: - format: per_pulse - group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" + format: per_train + index_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/index" + dataset_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/value" slice: 4 dldAuxChannels: sampleBias: 0 @@ -116,15 +124,24 @@ dataframe: timeStamp: format: per_train - group_name: "/uncategorised/FLASH.DIAG/TIMINGINFO/TIME1.BUNCH_FIRST_INDEX.1/" + index_key: "/uncategorised/FLASH.DIAG/TIMINGINFO/TIME1.BUNCH_FIRST_INDEX.1/index" + dataset_key: "/uncategorised/FLASH.DIAG/TIMINGINFO/TIME1.BUNCH_FIRST_INDEX.1/time" delayStage: format: per_train - group_name: "/zraw/FLASH.SYNC/LASER.LOCK.EXP/F1.PG.OSC/FMC0.MD22.1.ENCODER_POSITION.RD/dGroup/" + index_key: "/zraw/FLASH.SYNC/LASER.LOCK.EXP/F1.PG.OSC/FMC0.MD22.1.ENCODER_POSITION.RD/dGroup/index" + dataset_key: "/zraw/FLASH.SYNC/LASER.LOCK.EXP/F1.PG.OSC/FMC0.MD22.1.ENCODER_POSITION.RD/dGroup/value" + + pulserSignAdc: + format: per_pulse + index_key: "/FL1/Experiment/PG/SIS8300 100MHz ADC/CH6/TD/index" + dataset_key: "/FL1/Experiment/PG/SIS8300 100MHz ADC/CH6/TD/value" gmdTunnel: format: per_pulse - group_name: "/FL1/Photon Diagnostic/GMD/Pulse resolved energy/energy tunnel/" + index_key: "/FL1/Photon Diagnostic/GMD/Pulse resolved energy/energy tunnel/index" + dataset_key: "/FL1/Photon Diagnostic/GMD/Pulse resolved energy/energy tunnel/value" + slice: 0 # The prefixes of the stream names for different DAQ systems for parsing filenames # (Not to be changed by user) diff --git a/tests/loader/flash/conftest.py b/tests/loader/flash/conftest.py new file mode 100644 index 00000000..7cec83d8 --- /dev/null +++ b/tests/loader/flash/conftest.py @@ -0,0 +1,74 @@ +""" This module contains fixtures for the FEL module tests. +""" +import os +import shutil +from importlib.util import find_spec +from pathlib import Path + +import h5py +import pytest + +from sed.core.config import parse_config + +package_dir = os.path.dirname(find_spec("sed").origin) +config_path = os.path.join(package_dir, "../tests/data/loader/flash/config.yaml") +H5_PATH = "FLASH1_USER3_stream_2_run43878_file1_20230130T153807.1.h5" +H5_PATHS = [H5_PATH, "FLASH1_USER3_stream_2_run43879_file1_20230130T153807.1.h5"] + + +@pytest.fixture(name="config") +def fixture_config_file(): + """Fixture providing a configuration file for FlashLoader tests. + + Returns: + dict: The parsed configuration file. 
+ """ + return parse_config(config_path) + + +@pytest.fixture(name="config_dataframe") +def fixture_config_file_dataframe(): + """Fixture providing a configuration file for FlashLoader tests. + + Returns: + dict: The parsed configuration file. + """ + return parse_config(config_path)["dataframe"] + + +@pytest.fixture(name="h5_file") +def fixture_h5_file(): + """Fixture providing an open h5 file. + + Returns: + h5py.File: The open h5 file. + """ + return h5py.File(os.path.join(package_dir, f"../tests/data/loader/flash/{H5_PATH}"), "r") + + +@pytest.fixture(name="h5_file_copy") +def fixture_h5_file_copy(tmp_path): + """Fixture providing a copy of an open h5 file. + + Returns: + h5py.File: The open h5 file copy. + """ + # Create a copy of the h5 file in a temporary directory + original_file_path = os.path.join(package_dir, f"../tests/data/loader/flash/{H5_PATH}") + copy_file_path = tmp_path / "copy.h5" + shutil.copyfile(original_file_path, copy_file_path) + + # Open the copy in 'read-write' mode and return it + return h5py.File(copy_file_path, "r+") + + +@pytest.fixture(name="h5_paths") +def fixture_h5_paths(): + """Fixture providing a list of h5 file paths. + + Returns: + list: A list of h5 file paths. + """ + return [ + Path(os.path.join(package_dir, f"../tests/data/loader/flash/{path}")) for path in H5_PATHS + ] diff --git a/tests/loader/flash/test_buffer_handler.py b/tests/loader/flash/test_buffer_handler.py new file mode 100644 index 00000000..2df7cf17 --- /dev/null +++ b/tests/loader/flash/test_buffer_handler.py @@ -0,0 +1,218 @@ +"""Test cases for the BufferHandler class in the Flash module.""" +from copy import deepcopy +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from sed.loader.flash.buffer_handler import BufferHandler +from sed.loader.flash.utils import get_channels + + +def create_parquet_dir(config, folder): + """ + Creates a directory for storing Parquet files based on the provided configuration + and folder name. + """ + + parquet_path = Path(config["core"]["paths"]["data_parquet_dir"]) + parquet_path = parquet_path.joinpath(folder) + parquet_path.mkdir(parents=True, exist_ok=True) + return parquet_path + + +def test_get_files_to_read(config, h5_paths): + """ + Test the BufferHandler's ability to identify files that need to be read and + manage buffer file paths. + + This test performs several checks to ensure the BufferHandler correctly identifies + which HDF5 files need to be read and properly manages the paths for saving buffer + files. It follows these steps: + 1. Creates a directory structure for storing buffer files and initializes the BufferHandler. + 2. Invokes the private method _get_files_to_read to populate the list of missing HDF5 files and + verify that initially, all provided files are considered missing. + 3. Checks that the paths for saving buffer files are correctly generated. + 4. Creates a single buffer file and reruns the check to ensure that the BufferHandler recognizes + one less missing file. + 5. Cleans up by removing the created buffer file. + 6. Tests the handling of prefix and suffix in buffer file names by rerunning the checks with + modified file name parameters. 
+ """ + folder = create_parquet_dir(config, "get_files_to_read") + subfolder = folder.joinpath("buffer") + # set to false to avoid creating buffer files unnecessarily + bh = BufferHandler(config) + bh._get_files_to_read(h5_paths, folder, "", "", False) + + # check that all files are to be read + assert np.all(bh.missing_h5_files == h5_paths) + + # create expected paths + expected_buffer_paths = [Path(subfolder, f"{Path(path).stem}") for path in h5_paths] + + # check that all buffer paths are correct + assert np.all(bh.save_paths == expected_buffer_paths) + + # create only one buffer file + bh._save_buffer_file(h5_paths[0], expected_buffer_paths[0]) + # check again for files to read + bh._get_files_to_read(h5_paths, folder, "", "", False) + # check that only one file is to be read + assert len(bh.missing_h5_files) == len(h5_paths) - 1 + Path(expected_buffer_paths[0]).unlink() # remove buffer file + + # add prefix and suffix + bh._get_files_to_read(h5_paths, folder, "prefix", "suffix", False) + + # expected buffer paths with prefix and suffix + expected_buffer_paths = [ + Path(subfolder, f"prefix_{Path(path).stem}_suffix") for path in h5_paths + ] + assert np.all(bh.save_paths == expected_buffer_paths) + + +def test_buffer_schema_mismatch(config, h5_paths): + """ + Test function to verify schema mismatch handling in the FlashLoader's 'read_dataframe' method. + + The test validates the error handling mechanism when the available channels do not match the + schema of the existing parquet files. + + Test Steps: + - Attempt to read a dataframe after adding a new channel 'gmdTunnel2' to the configuration. + - Check for an expected error related to the mismatch between available channels and schema. + - Force recreation of dataframe with the added channel, ensuring successful dataframe + creation. + - Simulate a missing channel scenario by removing 'gmdTunnel2' from the configuration. + - Check for an error indicating a missing channel in the configuration. + - Clean up created buffer files after the test. + """ + folder = create_parquet_dir(config, "schema_mismatch") + bh = BufferHandler(config) + bh.run(h5_paths=h5_paths, folder=folder, debug=True) + + # Manipulate the configuration to introduce a new channel 'gmdTunnel2' + config_dict = config + config_dict["dataframe"]["channels"]["gmdTunnel2"] = { + "index_key": "/FL1/Photon Diagnostic/GMD/Pulse resolved energy/energy tunnel/index", + "dataset_key": "/FL1/Photon Diagnostic/GMD/Pulse resolved energy/energy tunnel/value", + "format": "per_pulse", + "slice": 0, + } + + # Reread the dataframe with the modified configuration, expecting a schema mismatch error + with pytest.raises(ValueError) as e: + bh = BufferHandler(config) + bh.run(h5_paths=h5_paths, folder=folder, debug=True) + expected_error = e.value.args[0] + + # Validate the specific error messages for schema mismatch + assert "The available channels do not match the schema of file" in expected_error + assert "Missing in parquet: {'gmdTunnel2'}" in expected_error + assert "Please check the configuration file or set force_recreate to True." 
in expected_error + + # Force recreation of the dataframe, including the added channel 'gmdTunnel2' + bh = BufferHandler(config) + bh.run(h5_paths=h5_paths, folder=folder, force_recreate=True, debug=True) + + # Remove 'gmdTunnel2' from the configuration to simulate a missing channel scenario + del config["dataframe"]["channels"]["gmdTunnel2"] + # also results in error but different from before + with pytest.raises(ValueError) as e: + # Attempt to read the dataframe again to check for the missing channel error + bh = BufferHandler(config) + bh.run(h5_paths=h5_paths, folder=folder, debug=True) + + expected_error = e.value.args[0] + # Check for the specific error message indicating a missing channel in the configuration + assert "Missing in config: {'gmdTunnel2'}" in expected_error + + # Clean up created buffer files after the test + [path.unlink() for path in bh.buffer_paths] + + +def test_save_buffer_files(config, h5_paths): + """ + Test the BufferHandler's ability to save buffer files serially and in parallel. + + This test ensures that the BufferHandler can run both serially and in parallel, saving the + output to buffer files, and then it compares the resulting DataFrames to ensure they are + identical. This verifies that parallel processing does not affect the integrity of the data + saved. After the comparison, it cleans up by removing the created buffer files. + """ + folder_serial = create_parquet_dir(config, "save_buffer_files_serial") + bh_serial = BufferHandler(config) + bh_serial.run(h5_paths, folder_serial, debug=True) + + folder_parallel = create_parquet_dir(config, "save_buffer_files_parallel") + bh_parallel = BufferHandler(config) + bh_parallel.run(h5_paths, folder_parallel) + + df_serial = pd.read_parquet(folder_serial) + df_parallel = pd.read_parquet(folder_parallel) + + pd.testing.assert_frame_equal(df_serial, df_parallel) + + # remove buffer files + [path.unlink() for path in bh_serial.buffer_paths] + [path.unlink() for path in bh_parallel.buffer_paths] + + +def test_save_buffer_files_exception(config, h5_paths, h5_file_copy, tmp_path): + """Test function to verify exception handling when running code in parallel.""" + folder_parallel = create_parquet_dir(config, "save_buffer_files_exception") + config_ = deepcopy(config) + + # check exception in case of missing channel in config + channel = "dldPosX" + del config_["dataframe"]["channels"][channel]["index_key"] + + # testing exception in parallel execution + with pytest.raises(ValueError): + bh = BufferHandler(config_) + bh.run(h5_paths, folder_parallel, debug=False) + + # check exception message with empty dataset + config_ = deepcopy(config) + channel = "testChannel" + channel_index_key = "test/dataset/empty/index" + empty_dataset_key = "test/dataset/empty/value" + config_["dataframe"]["channels"][channel] = { + "index_key": channel_index_key, + "dataset_key": empty_dataset_key, + "format": "per_train", + } + + # create an empty dataset + h5_file_copy.create_dataset( + name=empty_dataset_key, + shape=0, + ) + + # expect key error because of missing index dataset + with pytest.raises(KeyError): + bh = BufferHandler(config_) + bh.run([tmp_path / "copy.h5"], folder_parallel, debug=False, force_recreate=True) + + +def test_get_filled_dataframe(config, h5_paths): + """Test function to verify the creation of a filled dataframe from the buffer files.""" + folder = create_parquet_dir(config, "get_filled_dataframe") + bh = BufferHandler(config) + bh.run(h5_paths, folder) + + df = pd.read_parquet(folder) + + assert 
np.all(list(bh.df_electron.columns) == list(df.columns) + ["dldSectorID"]) + + channel_pulse = get_channels( + config["dataframe"]["channels"], + formats=["per_pulse", "per_train"], + index=True, + extend_aux=True, + ) + assert np.all(list(bh.df_pulse.columns) == channel_pulse) + # remove buffer files + [path.unlink() for path in bh.buffer_paths] diff --git a/tests/loader/flash/test_dataframe_creator.py b/tests/loader/flash/test_dataframe_creator.py new file mode 100644 index 00000000..6125e010 --- /dev/null +++ b/tests/loader/flash/test_dataframe_creator.py @@ -0,0 +1,271 @@ +"""Tests for DataFrameCreator functionality""" +import h5py +import numpy as np +import pytest +from pandas import DataFrame +from pandas import Index +from pandas import MultiIndex + +from sed.loader.flash.dataframe import DataFrameCreator +from sed.loader.flash.utils import get_channels + + +def test_get_index_dataset_key(config_dataframe, h5_paths): + """Test the creation of the index and dataset keys for a given channel.""" + config = config_dataframe + channel = "dldPosX" + df = DataFrameCreator(config, h5_paths[0]) + index_key, dataset_key = df.get_index_dataset_key(channel) + assert index_key == config["channels"][channel]["index_key"] + assert dataset_key == config["channels"][channel]["dataset_key"] + + # remove index_key + del config["channels"][channel]["index_key"] + with pytest.raises(ValueError): + df.get_index_dataset_key(channel) + + +def test_get_dataset_array(config_dataframe, h5_paths): + """Test the creation of a h5py dataset for a given channel.""" + + df = DataFrameCreator(config_dataframe, h5_paths[0]) + channel = "dldPosX" + + train_id, dset = df.get_dataset_array(channel) + # Check that the train_id and np_array have the correct shapes and types + assert isinstance(train_id, Index) + assert isinstance(dset, h5py.Dataset) + assert train_id.name == "trainId" + assert train_id.shape[0] == dset.shape[0] + assert dset.shape[1] == 5 + assert dset.shape[2] == 321 + + train_id, dset = df.get_dataset_array(channel, slice_=True) + assert train_id.shape[0] == dset.shape[0] + assert dset.shape[1] == 321 + + channel = "gmdTunnel" + train_id, dset = df.get_dataset_array(channel, True) + assert train_id.shape[0] == dset.shape[0] + assert dset.shape[1] == 500 + + +def test_empty_get_dataset_array(config_dataframe, h5_paths, h5_file_copy): + """Test the method when given an empty dataset.""" + + channel = "gmdTunnel" + df = DataFrameCreator(config_dataframe, h5_paths[0]) + train_id, dset = df.get_dataset_array(channel) + + channel_index_key = "/FL1/Photon Diagnostic/GMD/Pulse resolved energy/energy tunnel/index" + # channel_dataset_key = config_dataframe["channels"][channel]["group_name"] + "value" + empty_dataset_key = "/FL1/Photon Diagnostic/GMD/Pulse resolved energy/energy tunnel/empty" + config_dataframe["channels"][channel]["index_key"] = channel_index_key + config_dataframe["channels"][channel]["dataset_key"] = empty_dataset_key + + # create an empty dataset + h5_file_copy.create_dataset( + name=empty_dataset_key, + shape=(train_id.shape[0], 0), + ) + + df = DataFrameCreator(config_dataframe, h5_paths[0]) + df.h5_file = h5_file_copy + train_id, dset_empty = df.get_dataset_array(channel) + + assert dset_empty.shape[0] == train_id.shape[0] + assert dset.shape[1] == 8 + assert dset_empty.shape[1] == 0 + + +def test_pulse_index(config_dataframe, h5_paths): + """Test the creation of the pulse index for electron resolved data""" + + df = DataFrameCreator(config_dataframe, h5_paths[0]) + pulse_index, 
pulse_array = df.get_dataset_array("pulseId", slice_=True) + index, indexer = df.pulse_index(config_dataframe["ubid_offset"]) + # Check if the index_per_electron is a MultiIndex and has the correct levels + assert isinstance(index, MultiIndex) + assert set(index.names) == {"trainId", "pulseId", "electronId"} + + # Check if the pulse_index has the correct number of elements + # This should be the pulses without nan values + pulse_rav = pulse_array.ravel() + pulse_no_nan = pulse_rav[~np.isnan(pulse_rav)] + assert len(index) == len(pulse_no_nan) + + # Check if all pulseIds are correctly mapped to the index + assert np.all( + index.get_level_values("pulseId").values + == (pulse_no_nan - config_dataframe["ubid_offset"])[indexer], + ) + + assert np.all( + index.get_level_values("electronId").values[:5] == [0, 1, 0, 1, 0], + ) + + assert np.all( + index.get_level_values("electronId").values[-5:] == [1, 0, 1, 0, 1], + ) + + # check if all indexes are unique and monotonic increasing + assert index.is_unique + assert index.is_monotonic_increasing + + +def test_df_electron(config_dataframe, h5_paths): + """Test the creation of a pandas DataFrame for a channel of type [per electron].""" + df = DataFrameCreator(config_dataframe, h5_paths[0]) + + result_df = df.df_electron + + # check index levels + assert set(result_df.index.names) == {"trainId", "pulseId", "electronId"} + + # check that there are no nan values in the dataframe + assert ~result_df.isnull().values.any() + + # Check if first 5 values are as expected + # e.g. that the values are dropped for pulseId index below 0 (ubid_offset) + # however in this data the lowest value is 9 and offset was 5 so no values are dropped + assert np.all( + result_df.values[:5] + == np.array( + [ + [556.0, 731.0, 42888.0], + [549.0, 737.0, 42881.0], + [671.0, 577.0, 39181.0], + [671.0, 579.0, 39196.0], + [714.0, 859.0, 37530.0], + ], + dtype=np.float32, + ), + ) + assert np.all(result_df.index.get_level_values("pulseId") >= 0) + assert isinstance(result_df, DataFrame) + + assert result_df.index.is_unique + + # check that dataframe contains all subchannels + assert np.all( + set(result_df.columns) == set(get_channels(config_dataframe["channels"], ["per_electron"])), + ) + + +def test_create_dataframe_per_pulse(config_dataframe, h5_paths): + """Test the creation of a pandas DataFrame for a channel of type [per pulse].""" + df = DataFrameCreator(config_dataframe, h5_paths[0]) + result_df = df.df_pulse + # Check that the result_df is a DataFrame and has the correct shape + assert isinstance(result_df, DataFrame) + + _, data = df.get_dataset_array("gmdTunnel", slice_=True) + assert result_df.shape[0] == data.shape[0] * data.shape[1] + + # check index levels + assert set(result_df.index.names) == {"trainId", "pulseId", "electronId"} + + # all electronIds should be 0 + assert np.all(result_df.index.get_level_values("electronId") == 0) + + # pulse ids should span 0-499 on each train + for train_id in result_df.index.get_level_values("trainId"): + assert np.all( + result_df.loc[train_id].index.get_level_values("pulseId").values == np.arange(500), + ) + # assert index uniqueness + assert result_df.index.is_unique + + # assert that dataframe contains all channels + assert np.all( + set(result_df.columns) == set(get_channels(config_dataframe["channels"], ["per_pulse"])), + ) + + +def test_create_dataframe_per_train(config_dataframe, h5_paths): + """Test the creation of a pandas DataFrame for a channel of type [per train].""" + df = DataFrameCreator(config_dataframe, 
h5_paths[0]) + result_df = df.df_train + + channel = "delayStage" + key, data = df.get_dataset_array(channel, slice_=True) + + # Check that the result_df is a DataFrame and has the correct shape + assert isinstance(result_df, DataFrame) + + # check that all values are in the df for delayStage + assert np.all(result_df[channel].dropna() == data[()]) + + # check that dataframe contains all channels + assert np.all( + set(result_df.columns) + == set(get_channels(config_dataframe["channels"], ["per_train"], extend_aux=True)), + ) + + # Ensure DataFrame has rows equal to unique keys from "per_train" channels, considering + # different channels may have data for different trains. This checks the DataFrame's + # completeness and integrity, especially important when channels record at varying trains. + channels = get_channels(config_dataframe["channels"], ["per_train"]) + all_keys = Index([]) + for channel in channels: + # Append unique keys from each channel, considering only training data + all_keys = all_keys.append(df.get_dataset_array(channel, slice_=True)[0]) + # Verify DataFrame's row count matches unique train IDs count across channels + assert result_df.shape[0] == len(all_keys.unique()) + + # check index levels + assert set(result_df.index.names) == {"trainId", "pulseId", "electronId"} + + # all pulseIds and electronIds should be 0 + assert np.all(result_df.index.get_level_values("pulseId") == 0) + assert np.all(result_df.index.get_level_values("electronId") == 0) + + channel = "dldAux" + key, data = df.get_dataset_array(channel, slice_=True) + + # Check if the subchannels are correctly sliced into the dataframe + # The values are stored in DLD which is a 2D array + # The subchannels are stored in the second dimension + # Only index amount of values are stored in the first dimension, the rest are NaNs + # hence the slicing + subchannels = config_dataframe["channels"]["dldAux"]["dldAuxChannels"] + for subchannel, index in subchannels.items(): + assert np.all(df.df_train[subchannel].dropna().values == data[: key.size, index]) + + assert result_df.index.is_unique + + +def test_group_name_not_in_h5(config_dataframe, h5_paths): + """Test ValueError when the group_name for a channel does not exist in the H5 file.""" + channel = "dldPosX" + config = config_dataframe + config["channels"][channel]["dataset_key"] = "foo" + df = DataFrameCreator(config, h5_paths[0]) + + with pytest.raises(KeyError): + df.df_electron + + +def test_create_dataframe_per_file(config_dataframe, h5_paths): + """Test the creation of pandas DataFrames for a given file.""" + df = DataFrameCreator(config_dataframe, h5_paths[0]) + result_df = df.df + + # Check that the result_df is a DataFrame and has the correct shape + assert isinstance(result_df, DataFrame) + all_keys = df.df_train.index.append(df.df_electron.index).append(df.df_pulse.index) + all_keys = all_keys.unique() + assert result_df.shape[0] == len(all_keys.unique()) + + +def test_get_index_dataset_key_error(config_dataframe, h5_paths): + """ + Test that a ValueError is raised when the dataset_key is missing for a channel in the config. 
+ """ + config = config_dataframe + channel = "dldPosX" + df = DataFrameCreator(config, h5_paths[0]) + + del config["channels"][channel]["dataset_key"] + with pytest.raises(ValueError): + df.get_index_dataset_key(channel) diff --git a/tests/loader/flash/test_flash_loader.py b/tests/loader/flash/test_flash_loader.py index 88276107..a6f4a090 100644 --- a/tests/loader/flash/test_flash_loader.py +++ b/tests/loader/flash/test_flash_loader.py @@ -8,7 +8,7 @@ import pytest -from sed.core.config import parse_config +from .test_buffer_handler import create_parquet_dir from sed.loader.flash.loader import FlashLoader package_dir = os.path.dirname(find_spec("sed").origin) @@ -16,80 +16,12 @@ H5_PATH = "FLASH1_USER3_stream_2_run43878_file1_20230130T153807.1.h5" -@pytest.fixture(name="config_file") -def fixture_config_file() -> dict: - """Fixture providing a configuration file for FlashLoader tests. - - Returns: - dict: The parsed configuration file. - """ - return parse_config(config_path) - - -def test_get_channels_by_format(config_file: dict) -> None: - """ - Test function to verify the 'get_channels' method in FlashLoader class for - retrieving channels based on formats and index inclusion. - """ - # Initialize the FlashLoader instance with the given config_file. - fl = FlashLoader(config_file) - - # Define expected channels for each format. - electron_channels = ["dldPosX", "dldPosY", "dldTimeSteps"] - pulse_channels = [ - "sampleBias", - "tofVoltage", - "extractorVoltage", - "extractorCurrent", - "cryoTemperature", - "sampleTemperature", - "dldTimeBinSize", - "gmdTunnel", - ] - train_channels = ["timeStamp", "delayStage"] - index_channels = ["trainId", "pulseId", "electronId"] - - # Call get_channels method with different format options. - - # Request channels for 'per_electron' format using a list. - format_electron = fl.get_channels(["per_electron"]) - - # Request channels for 'per_pulse' format using a string. - format_pulse = fl.get_channels("per_pulse") - - # Request channels for 'per_train' format using a list. - format_train = fl.get_channels(["per_train"]) - - # Request channels for 'all' formats using a list. - format_all = fl.get_channels(["all"]) - - # Request index channels only. - format_index = fl.get_channels(index=True) - - # Request 'per_electron' format and include index channels. - format_index_electron = fl.get_channels(["per_electron"], index=True) - - # Request 'all' formats and include index channels. - format_all_index = fl.get_channels(["all"], index=True) - - # Assert that the obtained channels match the expected channels. - assert set(electron_channels) == set(format_electron) - assert set(pulse_channels) == set(format_pulse) - assert set(train_channels) == set(format_train) - assert set(electron_channels + pulse_channels + train_channels) == set(format_all) - assert set(index_channels) == set(format_index) - assert set(index_channels + electron_channels) == set(format_index_electron) - assert set(index_channels + electron_channels + pulse_channels + train_channels) == set( - format_all_index, - ) - - @pytest.mark.parametrize( "sub_dir", ["online-0/fl1user3/", "express-0/fl1user3/", "FL1USER3/"], ) -def test_initialize_paths( - config_file: dict, +def test_initialize_dirs( + config: dict, fs, sub_dir: Literal["online-0/fl1user3/", "express-0/fl1user3/", "FL1USER3/"], ) -> None: @@ -100,15 +32,15 @@ def test_initialize_paths( fs: A fixture for a fake file system. sub_dir (Literal["online-0/fl1user3/", "express-0/fl1user3/", "FL1USER3/"]): Sub-directory. 
""" - config = config_file - del config["core"]["paths"] - config["core"]["beamtime_id"] = "12345678" - config["core"]["year"] = "2000" + config_ = config.copy() + del config_["core"]["paths"] + config_["core"]["beamtime_id"] = "12345678" + config_["core"]["year"] = "2000" # Find base path of beamline from config. Here, we use pg2 - base_path = config["dataframe"]["beamtime_dir"]["pg2"] + base_path = config_["dataframe"]["beamtime_dir"]["pg2"] expected_path = ( - Path(base_path) / config["core"]["year"] / "data" / config["core"]["beamtime_id"] + Path(base_path) / config_["core"]["year"] / "data" / config_["core"]["beamtime_id"] ) # Create expected paths expected_raw_path = expected_path / "raw" / "hdf" / sub_dir @@ -118,110 +50,179 @@ def test_initialize_paths( fs.create_dir(expected_raw_path) fs.create_dir(expected_processed_path) - # Instance of class with correct config and call initialize_paths - fl = FlashLoader(config=config) - data_raw_dir, data_parquet_dir = fl.initialize_paths() + # Instance of class with correct config and call initialize_dirs + fl = FlashLoader(config=config_) + fl._initialize_dirs() + assert str(expected_raw_path) == fl.raw_dir + assert str(expected_processed_path) == fl.parquet_dir - assert expected_raw_path == data_raw_dir[0] - assert expected_processed_path == data_parquet_dir + # remove breamtimeid, year and daq from config to raise error + del config_["core"]["beamtime_id"] + with pytest.raises(ValueError) as e: + fl._initialize_dirs() + assert "The beamtime_id and year are required." in str(e.value) -def test_initialize_paths_filenotfound(config_file: dict) -> None: +def test_initialize_dirs_filenotfound(config: dict) -> None: """ Test FileNotFoundError during the initialization of paths. """ # Test the FileNotFoundError - config = config_file - del config["core"]["paths"] - config["core"]["beamtime_id"] = "11111111" - config["core"]["year"] = "2000" + config_ = config.copy() + del config_["core"]["paths"] + config_["core"]["beamtime_id"] = "11111111" + config_["core"]["year"] = "2000" - # Instance of class with correct config and call initialize_paths - fl = FlashLoader(config=config) + # Instance of class with correct config and call initialize_dirs with pytest.raises(FileNotFoundError): - _, _ = fl.initialize_paths() + fl = FlashLoader(config=config_) + fl._initialize_dirs() -def test_invalid_channel_format(config_file: dict) -> None: +def test_save_read_parquet_flash(config): """ - Test ValueError for an invalid channel format. + Test the functionality of saving and reading parquet files with FlashLoader. + + This test performs three main actions: + 1. First call to create and read parquet files. Verifies new files are created. + 2. Second call with the same parameters to check that it only reads from + the existing parquet files without creating new ones. It asserts that the files' modification + times remain unchanged, indicating no new files were created or existing files overwritten. + 3. Third call with `force_recreate=True` to force the recreation of parquet files. + It verifies that the files were indeed overwritten by checking that their modification + times have changed. 
""" - config = config_file - config["dataframe"]["channels"]["dldPosX"]["format"] = "foo" + config_ = config.copy() + data_parquet_dir = create_parquet_dir(config_, "flash_save_read") + config_["core"]["paths"]["data_parquet_dir"] = data_parquet_dir + fl = FlashLoader(config=config_) + + # First call: should create and read the parquet file + df1, _, _ = fl.read_dataframe(runs=[43878, 43879]) + # Check if new files were created + data_parquet_dir = data_parquet_dir.joinpath("buffer") + new_files = { + file: os.path.getmtime(data_parquet_dir.joinpath(file)) + for file in os.listdir(data_parquet_dir) + } + assert new_files - fl = FlashLoader(config=config) + # Second call: should only read the parquet file, not create new ones + df2, _, _ = fl.read_dataframe(runs=[43878, 43879]) - with pytest.raises(ValueError): - fl.read_dataframe() + # Verify no new files were created after the second call + final_files = { + file: os.path.getmtime(data_parquet_dir.joinpath(file)) + for file in os.listdir(data_parquet_dir) + } + assert ( + new_files == final_files + ), "Files were overwritten or new files were created after the second call." + # Third call: We force_recreate the parquet files + df3, _, _ = fl.read_dataframe(runs=[43878, 43879], force_recreate=True) -def test_group_name_not_in_h5(config_file: dict) -> None: - """ - Test ValueError when the group_name for a channel does not exist in the H5 file. - """ - config = config_file - config["dataframe"]["channels"]["dldPosX"]["group_name"] = "foo" - fl = FlashLoader(config=config) + # Verify files were overwritten + new_files = { + file: os.path.getmtime(data_parquet_dir.joinpath(file)) + for file in os.listdir(data_parquet_dir) + } + assert new_files != final_files, "Files were not overwritten after the third call." - with pytest.raises(ValueError) as e: - fl.create_dataframe_per_file(Path(config["core"]["paths"]["data_raw_dir"] + H5_PATH)) + # remove the parquet files + [data_parquet_dir.joinpath(file).unlink() for file in new_files] - assert str(e.value.args[0]) == "The group_name for channel dldPosX does not exist." +def test_get_elapsed_time_fid(config): + """Test get_elapsed_time method of FlashLoader class""" + # Create an instance of FlashLoader + fl = FlashLoader(config=config) -def test_buffer_schema_mismatch(config_file: dict) -> None: - """ - Test function to verify schema mismatch handling in the FlashLoader's 'read_dataframe' method. - - The test validates the error handling mechanism when the available channels do not match the - schema of the existing parquet files. - - Test Steps: - - Attempt to read a dataframe after adding a new channel 'gmdTunnel2' to the configuration. - - Check for an expected error related to the mismatch between available channels and schema. - - Force recreation of dataframe with the added channel, ensuring successful dataframe creation. - - Simulate a missing channel scenario by removing 'gmdTunnel2' from the configuration. - - Check for an error indicating a missing channel in the configuration. - - Clean up created buffer files after the test. 
- """ - fl = FlashLoader(config=config_file) + # Mock the file_statistics and files + fl.metadata = { + "file_statistics": { + "0": {"time_stamps": [10, 20]}, + "1": {"time_stamps": [20, 30]}, + "2": {"time_stamps": [30, 40]}, + }, + } + fl.files = ["file0", "file1", "file2"] + + # Test get_elapsed_time with fids + assert fl.get_elapsed_time(fids=[0, 1]) == 20 + + # # Test get_elapsed_time with runs + # # Assuming get_files_from_run_id(43878) returns ["file0", "file1"] + # assert fl.get_elapsed_time(runs=[43878]) == 20 + + # Test get_elapsed_time with aggregate=False + assert fl.get_elapsed_time(fids=[0, 1], aggregate=False) == [10, 10] + + # Test KeyError when file_statistics is missing + fl.metadata = {"something": "else"} + with pytest.raises(KeyError) as e: + fl.get_elapsed_time(fids=[0, 1]) + + assert "File statistics missing. Use 'read_dataframe' first." in str(e.value) + # Test KeyError when time_stamps is missing + fl.metadata = { + "file_statistics": { + "0": {}, + "1": {"time_stamps": [20, 30]}, + }, + } + with pytest.raises(KeyError) as e: + fl.get_elapsed_time(fids=[0, 1]) + + assert "Timestamp metadata missing in file 0" in str(e.value) - # Read a dataframe for a specific run - fl.read_dataframe(runs=["43878"]) - # Manipulate the configuration to introduce a new channel 'gmdTunnel2' - config = config_file - config["dataframe"]["channels"]["gmdTunnel2"] = { - "group_name": "/FL1/Photon Diagnostic/GMD/Pulse resolved energy/energy tunnel/", - "format": "per_pulse", +def test_get_elapsed_time_run(config): + config_ = config.copy() + config_["core"]["paths"] = { + "data_raw_dir": "tests/data/loader/flash/", + "data_parquet_dir": "tests/data/loader/flash/parquet/get_elapsed_time_run", } + """Test get_elapsed_time method of FlashLoader class""" + # Create an instance of FlashLoader + fl = FlashLoader(config=config_) - # Reread the dataframe with the modified configuration, expecting a schema mismatch error - fl = FlashLoader(config=config) - with pytest.raises(ValueError) as e: - fl.read_dataframe(runs=["43878"]) - expected_error = e.value.args + fl.read_dataframe(runs=[43878, 43879]) + start, end = fl.metadata["file_statistics"]["0"]["time_stamps"] + expected_elapsed_time_0 = end - start + start, end = fl.metadata["file_statistics"]["1"]["time_stamps"] + expected_elapsed_time_1 = end - start + + elapsed_time = fl.get_elapsed_time(runs=[43878]) + assert elapsed_time == expected_elapsed_time_0 - # Validate the specific error messages for schema mismatch - assert "The available channels do not match the schema of file" in expected_error[0] - assert expected_error[2] == "Missing in parquet: {'gmdTunnel2'}" - assert expected_error[4] == "Please check the configuration file or set force_recreate to True." 
+ elapsed_time = fl.get_elapsed_time(runs=[43878, 43879], aggregate=False) + assert elapsed_time == [expected_elapsed_time_0, expected_elapsed_time_1] - # Force recreation of the dataframe, including the added channel 'gmdTunnel2' - fl.read_dataframe(runs=["43878"], force_recreate=True) + elapsed_time = fl.get_elapsed_time(runs=[43878, 43879]) + start, end = fl.metadata["file_statistics"]["1"]["time_stamps"] + assert elapsed_time == expected_elapsed_time_0 + expected_elapsed_time_1 - # Remove 'gmdTunnel2' from the configuration to simulate a missing channel scenario - del config["dataframe"]["channels"]["gmdTunnel2"] + +def test_available_runs(monkeypatch, config): + """Test available_runs property of FlashLoader class""" + # Create an instance of FlashLoader fl = FlashLoader(config=config) - with pytest.raises(ValueError) as e: - # Attempt to read the dataframe again to check for the missing channel error - fl.read_dataframe(runs=["43878"]) - expected_error = e.value.args - # Check for the specific error message indicating a missing channel in the configuration - assert expected_error[3] == "Missing in config: {'gmdTunnel2'}" + # Mock the raw_dir and files + fl.raw_dir = "/path/to/raw_dir" + files = [ + "run1_file1.h5", + "run3_file1.h5", + "run2_file1.h5", + "run1_file2.h5", + ] + + # Mock the glob method to return the mock files + def mock_glob(*args, **kwargs): # noqa: ARG001 + return [Path(fl.raw_dir, file) for file in files] + + monkeypatch.setattr(Path, "glob", mock_glob) - # Clean up created buffer files after the test - _, parquet_data_dir = fl.initialize_paths() - for file in os.listdir(Path(parquet_data_dir, "buffer")): - os.remove(Path(parquet_data_dir, "buffer", file)) + # Test available_runs + assert fl.available_runs == [1, 2, 3] diff --git a/tests/loader/flash/test_utils.py b/tests/loader/flash/test_utils.py new file mode 100644 index 00000000..08891320 --- /dev/null +++ b/tests/loader/flash/test_utils.py @@ -0,0 +1,110 @@ +"""Tests for utils functionality""" +import pytest + +from .test_buffer_handler import create_parquet_dir +from sed.loader.flash.utils import get_channels +from sed.loader.flash.utils import initialize_paths + +# Define expected channels for each format. +ELECTRON_CHANNELS = ["dldPosX", "dldPosY", "dldTimeSteps"] +PULSE_CHANNELS = ["pulserSignAdc", "gmdTunnel"] +TRAIN_CHANNELS = ["timeStamp", "delayStage", "dldAux"] +TRAIN_CHANNELS_EXTENDED = [ + "sampleBias", + "tofVoltage", + "extractorVoltage", + "extractorCurrent", + "cryoTemperature", + "sampleTemperature", + "dldTimeBinSize", + "timeStamp", + "delayStage", +] +INDEX_CHANNELS = ["trainId", "pulseId", "electronId"] + + +def test_get_channels_by_format(config_dataframe): + """ + Test function to verify the 'get_channels' method in FlashLoader class for + retrieving channels based on formats and index inclusion. + """ + # Initialize the FlashLoader instance with the given config_file. + ch_dict = config_dataframe["channels"] + + # Call get_channels method with different format options. + + # Request channels for 'per_electron' format using a list. + format_electron = get_channels(ch_dict, ["per_electron"]) + + # Request channels for 'per_pulse' format using a string. + format_pulse = get_channels(ch_dict, "per_pulse") + + # Request channels for 'per_train' format without expanding the dldAuxChannels. + format_train = get_channels(ch_dict, "per_train", extend_aux=False) + + # Request channels for 'per_train' format using a list, and expand the dldAuxChannels. 
+ format_train_extended = get_channels(ch_dict, ["per_train"], extend_aux=True) + + # Request channels for 'all' formats using a list. + format_all = get_channels(ch_dict, ["all"]) + + # Request index channels only. No need for channel_dict. + format_index = get_channels(index=True) + + # Request 'per_electron' format and include index channels. + format_index_electron = get_channels(ch_dict, ["per_electron"], index=True) + + # Request 'all' formats and include index channels. + format_all_index = get_channels(ch_dict, ["all"], index=True) + + # Request 'all' formats and include index channels and extend aux channels + format_all_index_extend_aux = get_channels(ch_dict, ["all"], index=True, extend_aux=True) + + # Assert that the obtained channels match the expected channels. + assert set(ELECTRON_CHANNELS) == set(format_electron) + assert set(TRAIN_CHANNELS_EXTENDED) == set(format_train_extended) + assert set(TRAIN_CHANNELS) == set(format_train) + assert set(PULSE_CHANNELS) == set(format_pulse) + assert set(ELECTRON_CHANNELS + TRAIN_CHANNELS + PULSE_CHANNELS) == set(format_all) + assert set(INDEX_CHANNELS) == set(format_index) + assert set(INDEX_CHANNELS + ELECTRON_CHANNELS) == set(format_index_electron) + assert set(INDEX_CHANNELS + ELECTRON_CHANNELS + TRAIN_CHANNELS + PULSE_CHANNELS) == set( + format_all_index, + ) + assert set( + INDEX_CHANNELS + ELECTRON_CHANNELS + PULSE_CHANNELS + TRAIN_CHANNELS_EXTENDED, + ) == set( + format_all_index_extend_aux, + ) + + +def test_parquet_init_error(): + """Test ParquetHandler initialization error""" + with pytest.raises(ValueError) as e: + _ = initialize_paths(filenames="test") + + assert "Please provide folder or paths." in str(e.value) + + with pytest.raises(ValueError) as e: + _ = initialize_paths(folder="test") + + assert "With folder, please provide filenames." in str(e.value) + + +def test_initialize_paths(config): + """Test ParquetHandler initialization""" + folder = create_parquet_dir(config, "parquet_init") + + ph = initialize_paths("test", folder, extension="xyz") + assert ph[0].suffix == ".xyz" + assert ph[0].name == "test.xyz" + + # test prefix and suffix + ph = initialize_paths("test", folder, prefix="prefix", suffix="suffix") + assert ph[0].name == "prefix_test_suffix.parquet" + + # test with list of parquet_names and subfolder + ph = initialize_paths(["test1", "test2"], folder, subfolder="subfolder") + assert ph[0].parent.name == "subfolder" + assert ph[0].name == "test1.parquet" + assert ph[1].name == "test2.parquet" diff --git a/tests/loader/sxp/test_sxp_loader.py b/tests/loader/sxp/test_sxp_loader.py index 83ade005..cc8698a3 100644 --- a/tests/loader/sxp/test_sxp_loader.py +++ b/tests/loader/sxp/test_sxp_loader.py @@ -74,7 +74,7 @@ def test_get_channels_by_format(config_file: dict) -> None: ) -def test_initialize_paths(config_file: dict, fs) -> None: +def test_initialize_dirs(config_file: dict, fs) -> None: """ Test the initialization of paths based on the configuration and directory structures. 
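For orientation, the renamed tests above and below rely on the new directory API: _initialize_dirs() no longer returns the directories but stores them on the loader instance, so callers read raw_dir and parquet_dir afterwards. A minimal usage sketch, assuming the import path from the file layout and a placeholder config path:

# illustrative sketch, not part of the patch; the config path is a placeholder
from sed.core.config import parse_config
from sed.loader.sxp.loader import SXPLoader

config = parse_config("tests/data/loader/sxp/config.yaml")  # placeholder path
loader = SXPLoader(config=config)
loader._initialize_dirs()          # resolves and stores the directories
raw_dirs = loader.raw_dir          # a list of raw data directories for the SXP loader
parquet_dir = loader.parquet_dir   # previously: raw_dirs, parquet_dir = loader.initialize_paths()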
@@ -97,15 +97,15 @@ def test_initialize_paths(config_file: dict, fs) -> None: fs.create_dir(expected_raw_path) fs.create_dir(expected_processed_path) - # Instance of class with correct config and call initialize_paths + # Instance of class with correct config and call initialize_dirs sl = SXPLoader(config=config) - data_raw_dir, data_parquet_dir = sl.initialize_paths() + sl._initialize_dirs() - assert expected_raw_path == data_raw_dir[0] - assert expected_processed_path == data_parquet_dir + assert expected_raw_path == sl.raw_dir[0] + assert expected_processed_path == sl.parquet_dir -def test_initialize_paths_filenotfound(config_file: dict): +def test_initialize_dirs_filenotfound(config_file: dict): """ Test FileNotFoundError during the initialization of paths. """ @@ -115,10 +115,10 @@ def test_initialize_paths_filenotfound(config_file: dict): config["core"]["beamtime_id"] = "11111111" config["core"]["year"] = "2000" - # Instance of class with correct config and call initialize_paths + # Instance of class with correct config and call initialize_dirs sl = SXPLoader(config=config) with pytest.raises(FileNotFoundError): - _, _ = sl.initialize_paths() + sl._initialize_dirs() def test_invalid_channel_format(config_file: dict): @@ -209,6 +209,6 @@ def test_buffer_schema_mismatch(config_file: dict): assert expected_error[3] == "Missing in config: {'delayStage2'}" # Clean up created buffer files after the test - _, parquet_data_dir = sl.initialize_paths() - for file in os.listdir(Path(parquet_data_dir, "buffer")): - os.remove(Path(parquet_data_dir, "buffer", file)) + sl._initialize_dirs() + for file in os.listdir(Path(sl.parquet_dir, "buffer")): + os.remove(Path(sl.parquet_dir, "buffer", file)) diff --git a/tests/loader/test_loaders.py b/tests/loader/test_loaders.py index 8958b1a1..020e926b 100644 --- a/tests/loader/test_loaders.py +++ b/tests/loader/test_loaders.py @@ -164,9 +164,9 @@ def test_has_correct_read_dataframe_func(loader: BaseLoader, read_type: str) -> if loader.__name__ in {"flash", "sxp"}: loader = cast(FlashLoader, loader) - _, parquet_data_dir = loader.initialize_paths() - for file in os.listdir(Path(parquet_data_dir, "buffer")): - os.remove(Path(parquet_data_dir, "buffer", file)) + loader._initialize_dirs() + for file in os.listdir(Path(loader.parquet_dir, "buffer")): + os.remove(Path(loader.parquet_dir, "buffer", file)) @pytest.mark.parametrize("loader", get_all_loaders()) @@ -197,9 +197,9 @@ def test_timed_dataframe(loader: BaseLoader) -> None: if loaded_timed_dataframe is None: if loader.__name__ in {"flash", "sxp"}: loader = cast(FlashLoader, loader) - _, parquet_data_dir = loader.initialize_paths() - for file in os.listdir(Path(parquet_data_dir, "buffer")): - os.remove(Path(parquet_data_dir, "buffer", file)) + loader._initialize_dirs() + for file in os.listdir(Path(loader.parquet_dir, "buffer")): + os.remove(Path(loader.parquet_dir, "buffer", file)) pytest.skip("Not implemented") assert isinstance(loaded_timed_dataframe, ddf.DataFrame) assert set(loaded_timed_dataframe.columns).issubset(set(loaded_dataframe.columns)) @@ -207,9 +207,9 @@ def test_timed_dataframe(loader: BaseLoader) -> None: if loader.__name__ in {"flash", "sxp"}: loader = cast(FlashLoader, loader) - _, parquet_data_dir = loader.initialize_paths() - for file in os.listdir(Path(parquet_data_dir, "buffer")): - os.remove(Path(parquet_data_dir, "buffer", file)) + loader._initialize_dirs() + for file in os.listdir(Path(loader.parquet_dir, "buffer")): + os.remove(Path(loader.parquet_dir, "buffer", file)) 
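The cleanup idiom added here recurs throughout test_loaders.py: resolve the directories, then delete whatever buffer parquet files were written under the loader's parquet_dir/buffer. A small helper along these lines (hypothetical, not part of the patch) captures the pattern:

# hypothetical helper, not part of the patch
import os
from pathlib import Path

def cleanup_buffer_files(loader) -> None:
    """Remove buffer parquet files created by a flash/sxp loader during a test."""
    loader._initialize_dirs()
    buffer_dir = Path(loader.parquet_dir, "buffer")
    for file in os.listdir(buffer_dir):
        os.remove(buffer_dir / file)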
@pytest.mark.parametrize("loader", get_all_loaders()) @@ -241,9 +241,9 @@ def test_get_count_rate(loader: BaseLoader) -> None: if loaded_time is None and loaded_countrate is None: if loader.__name__ in {"flash", "sxp"}: loader = cast(FlashLoader, loader) - _, parquet_data_dir = loader.initialize_paths() - for file in os.listdir(Path(parquet_data_dir, "buffer")): - os.remove(Path(parquet_data_dir, "buffer", file)) + loader._initialize_dirs() + for file in os.listdir(Path(loader.parquet_dir, "buffer")): + os.remove(Path(loader.parquet_dir, "buffer", file)) pytest.skip("Not implemented") assert len(loaded_time) == len(loaded_countrate) loaded_time2, loaded_countrate2 = loader.get_count_rate(fids=[0]) @@ -252,9 +252,9 @@ def test_get_count_rate(loader: BaseLoader) -> None: if loader.__name__ in {"flash", "sxp"}: loader = cast(FlashLoader, loader) - _, parquet_data_dir = loader.initialize_paths() - for file in os.listdir(Path(parquet_data_dir, "buffer")): - os.remove(Path(parquet_data_dir, "buffer", file)) + loader._initialize_dirs() + for file in os.listdir(Path(loader.parquet_dir, "buffer")): + os.remove(Path(loader.parquet_dir, "buffer", file)) @pytest.mark.parametrize("loader", get_all_loaders()) @@ -284,11 +284,11 @@ def test_get_elapsed_time(loader: BaseLoader) -> None: ) elapsed_time = loader.get_elapsed_time() if elapsed_time is None: - if loader.__name__ in {"flash", "sxp"}: + if loader.__name__ in {"sxp"}: loader = cast(FlashLoader, loader) - _, parquet_data_dir = loader.initialize_paths() - for file in os.listdir(Path(parquet_data_dir, "buffer")): - os.remove(Path(parquet_data_dir, "buffer", file)) + loader._initialize_dirs() + for file in os.listdir(Path(loader.parquet_dir, "buffer")): + os.remove(Path(loader.parquet_dir, "buffer", file)) pytest.skip("Not implemented") assert elapsed_time > 0 elapsed_time2 = loader.get_elapsed_time(fids=[0]) @@ -297,9 +297,9 @@ def test_get_elapsed_time(loader: BaseLoader) -> None: if loader.__name__ in {"flash", "sxp"}: loader = cast(FlashLoader, loader) - _, parquet_data_dir = loader.initialize_paths() - for file in os.listdir(Path(parquet_data_dir, "buffer")): - os.remove(Path(parquet_data_dir, "buffer", file)) + loader._initialize_dirs() + for file in os.listdir(Path(loader.parquet_dir, "buffer")): + os.remove(Path(loader.parquet_dir, "buffer", file)) def test_mpes_timestamps() -> None: diff --git a/tests/test_processor.py b/tests/test_processor.py index 1514dd3f..7865afd2 100644 --- a/tests/test_processor.py +++ b/tests/test_processor.py @@ -644,7 +644,6 @@ def test_align_dld_sectors() -> None: assert "dldSectorID" in processor.dataframe.columns sector_delays = np.asarray([10, -10, 20, -20, 30, -30, 40, -40]) - tof_ref = [] for i in range(len(sector_delays)): tof_ref.append(