diff --git a/.gitignore b/.gitignore index 5cc3ff6b..f41a5234 100644 --- a/.gitignore +++ b/.gitignore @@ -140,3 +140,6 @@ dmypy.json # IDE stuff \.vscode + +# Mac stuff +.DS_Store diff --git a/sed/loader/fel/__init__.py b/sed/loader/fel/__init__.py new file mode 100644 index 00000000..1141faa4 --- /dev/null +++ b/sed/loader/fel/__init__.py @@ -0,0 +1,11 @@ +"""sed.loader.fel module easy access APIs +""" +from .buffer import BufferHandler +from .dataframe import DataFrameCreator +from .parquet import ParquetHandler + +__all__ = [ + "BufferHandler", + "DataFrameCreator", + "ParquetHandler", +] diff --git a/sed/loader/fel/buffer.py b/sed/loader/fel/buffer.py new file mode 100644 index 00000000..edd8cdb2 --- /dev/null +++ b/sed/loader/fel/buffer.py @@ -0,0 +1,249 @@ +""" +The BufferFileHandler uses the DataFrameCreator class and uses the ParquetHandler class to +manage buffer files. It provides methods for initializing paths, checking the schema, +determining the list of files to read, serializing and parallelizing the creation, and reading +all files into one Dask DataFrame. + +After initialization, the electron and timed dataframes can be accessed as: + + buffer_handler = BufferFileHandler(config, h5_paths, folder) + + buffer_handler.electron_dataframe + buffer_handler.pulse_dataframe + +Force_recreate flag forces recreation of buffer files. Useful when the schema has changed. +Debug mode serializes the creation of buffer files. +""" +from __future__ import annotations + +from itertools import compress +from pathlib import Path +from typing import Type + +import dask.dataframe as ddf +import h5py +import pyarrow.parquet as pq +from joblib import delayed +from joblib import Parallel + +from sed.core.dfops import forward_fill_lazy +from sed.loader.fel.config_model import DataFrameConfig +from sed.loader.fel.dataframe import DataFrameCreator +from sed.loader.fel.parquet import ParquetHandler +from sed.loader.fel.utils import get_channels +from sed.loader.utils import split_dld_time_from_sector_id + + +class BufferHandler: + """ + A class for handling the creation and manipulation of buffer files using DataFrameCreator + and ParquetHandler. + """ + + def __init__( + self, + df_creator: type[DataFrameCreator], + config: DataFrameConfig, + h5_paths: list[Path], + folder: Path, + force_recreate: bool = False, + prefix: str = "", + suffix: str = "", + debug: bool = False, + auto: bool = True, + ) -> None: + """ + Initializes the BufferFileHandler. + + Args: + df_creator (Type[DataFrameCreator]): Derived class based on DataFrameCreator. + config (DataFrameConfig): The dataframe section of the config model. + h5_paths (List[Path]): List of paths to H5 files. + folder (Path): Path to the folder for buffer files. + force_recreate (bool): Flag to force recreation of buffer files. + prefix (str): Prefix for buffer file names. + suffix (str): Suffix for buffer file names. + debug (bool): Flag to enable debug mode. + auto (bool): Flag to automatically create buffer files and fill the dataframe. 
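+
+ Example (illustrative sketch; config is a parsed LoaderConfig, h5_paths and
+ folder are assumed to exist, and FlashDataFrameCreator is one concrete
+ df_creator implementation):
+
+ buffer_handler = BufferHandler(
+ FlashDataFrameCreator, config.dataframe, h5_paths, folder
+ )
+ df_electron = buffer_handler.dataframe_electron
+ df_timed = buffer_handler.dataframe_pulse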
+ """ + self.df_creator = df_creator + self._config = config + + self.buffer_paths: list[Path] = [] + self.h5_to_create: list[Path] = [] + self.buffer_to_create: list[Path] = [] + + self.dataframe_electron: ddf.DataFrame = None + self.dataframe_pulse: ddf.DataFrame = None + + # In auto mode, these methods are called automatically + if auto: + self.get_files_to_read(h5_paths, folder, prefix, suffix, force_recreate) + + if not force_recreate: + self.schema_check() + + self.create_buffer_files(debug) + + self.get_filled_dataframe() + + def schema_check(self) -> None: + """ + Checks the schema of the Parquet files. + + Raises: + ValueError: If the schema of the Parquet files does not match the configuration. + """ + existing_parquet_filenames = [file for file in self.buffer_paths if file.exists()] + parquet_schemas = [pq.read_schema(file) for file in existing_parquet_filenames] + config_schema = set( + get_channels(self._config.channels, formats="all", index=True, extend_aux=True), + ) + + for i, schema in enumerate(parquet_schemas): + schema_set = set(schema.names) + if schema_set != config_schema: + missing_in_parquet = config_schema - schema_set + missing_in_config = schema_set - config_schema + + missing_in_parquet_str = ( + f"Missing in parquet: {missing_in_parquet}" if missing_in_parquet else "" + ) + missing_in_config_str = ( + f"Missing in config: {missing_in_config}" if missing_in_config else "" + ) + + raise ValueError( + "The available channels do not match the schema of file", + f"{existing_parquet_filenames[i]}", + f"{missing_in_parquet_str}", + f"{missing_in_config_str}", + "Please check the configuration file or set force_recreate to True.", + ) + + def get_files_to_read( + self, + h5_paths: list[Path], + folder: Path, + prefix: str, + suffix: str, + force_recreate: bool, + ) -> None: + """ + Determines the list of files to read and the corresponding buffer files to create. + + Args: + h5_paths (List[Path]): List of paths to H5 files. + folder (Path): Path to the folder for buffer files. + prefix (str): Prefix for buffer file names. + suffix (str): Suffix for buffer file names. + force_recreate (bool): Flag to force recreation of buffer files. + """ + # Getting the paths of the buffer files, with subfolder as buffer and no extension + pq_handler = ParquetHandler( + [Path(h5_path).stem for h5_path in h5_paths], + folder, + "buffer", + prefix, + suffix, + extension="", + ) + self.buffer_paths = pq_handler.parquet_paths + # read only the files that do not exist or if force_recreate is True + files_to_read = [ + force_recreate or not parquet_path.exists() for parquet_path in self.buffer_paths + ] + + # Get the list of H5 files to read and the corresponding buffer files to create + self.h5_to_create = list(compress(h5_paths, files_to_read)) + self.buffer_to_create = list(compress(self.buffer_paths, files_to_read)) + + self.num_files = len(self.h5_to_create) + + print(f"Reading files: {self.num_files} new files of {len(h5_paths)} total.") + + def _create_buffer_file(self, h5_path: Path, parquet_path: Path) -> None: + """ + Creates a single buffer file. Useful because h5py.File cannot be pickled if left open. + + Args: + h5_path (Path): Path to the H5 file. + parquet_path (Path): Path to the buffer file. 
+ """ + # Open the h5 file in read mode + h5_file = h5py.File(h5_path, "r") + + # Create a DataFrameCreator instance with the configuration and the h5 file + dfc = self.df_creator(self._config, h5_file) + + # Get the DataFrame from the DataFrameCreator instance + df = dfc.df + + # Close the h5 file + h5_file.close() + + # Reset the index of the DataFrame and save it as a parquet file + df.reset_index().to_parquet(parquet_path) + + def create_buffer_files(self, debug: bool) -> None: + """ + Creates the buffer files. + + Args: + debug (bool): Flag to enable debug mode, which serializes the creation. + """ + if self.num_files > 0: + if debug: + for h5_path, parquet_path in zip(self.h5_to_create, self.buffer_to_create): + self._create_buffer_file(h5_path, parquet_path) + else: + Parallel(n_jobs=self.num_files, verbose=10)( + delayed(self._create_buffer_file)(h5_path, parquet_path) + for h5_path, parquet_path in zip(self.h5_to_create, self.buffer_to_create) + ) + + def get_filled_dataframe(self) -> None: + """ + Reads all parquet files into one dataframe using dask and fills NaN values. + """ + dataframe = ddf.read_parquet(self.buffer_paths, calculate_divisions=True) + metadata = [pq.read_metadata(file) for file in self.buffer_paths] + + channels: list[str] = get_channels( + self._config.channels, + ["per_pulse", "per_train"], + extend_aux=True, + ) + index: list[str] = get_channels(index=True) + overlap = min(file.num_rows for file in metadata) + + print("Filling nan values...") + dataframe = forward_fill_lazy( + df=dataframe, + columns=channels, + before=overlap, + iterations=self._config.forward_fill_iterations, + ) + + # Drop rows with nan values in the tof column + dataframe_electron = dataframe.dropna(subset=self._config.tof_column) + + # Set the dtypes of the channels here as there should be no null values + ch_names = get_channels(self._config.channels, "all") + cfg_ch = self._config.channels + dtypes = { + channel: cfg_ch[channel].dtype + for channel in ch_names + if cfg_ch[channel].dtype is not None + } + + # Correct the 3-bit shift which encodes the detector ID in the 8s time + if self._config.split_sector_id_from_dld_time: + dataframe_electron = split_dld_time_from_sector_id( + dataframe_electron, + self._config.tof_column, + self._config.sector_id_column, + self._config.sector_id_reserved_bits, + ) + self.dataframe_electron = dataframe_electron.astype(dtypes) + self.dataframe_pulse = dataframe[index + channels] diff --git a/sed/loader/fel/config_model.py b/sed/loader/fel/config_model.py new file mode 100644 index 00000000..f8c11c52 --- /dev/null +++ b/sed/loader/fel/config_model.py @@ -0,0 +1,263 @@ +from enum import Enum +from pathlib import Path +from typing import Any +from typing import Dict +from typing import Optional + +from pydantic import BaseModel +from pydantic import DirectoryPath +from pydantic import field_validator +from pydantic import model_validator + + +class DataFormat(str, Enum): + PER_ELECTRON = "per_electron" + PER_PULSE = "per_pulse" + PER_TRAIN = "per_train" + + +class DataPaths(BaseModel): + """ + Represents paths for raw and parquet data in a beamtime directory. 
+ """ + + data_raw_dir: DirectoryPath + data_parquet_dir: DirectoryPath + + @field_validator("data_parquet_dir", mode="before") + @classmethod + def check_and_create_parquet_dir(cls, v): + v = Path(v) + if not v.is_dir(): + v.mkdir(parents=True, exist_ok=True) + return v + + @classmethod + def from_beamtime_dir( + cls, + loader: str, + beamtime_dir: Path, + beamtime_id: int, + year: int, + daq: str, + ) -> "DataPaths": + """ + Create DataPaths instance from a beamtime directory and DAQ type. + + Parameters: + - beamtime_dir (Path): Path to the beamtime directory. + - daq (str): Type of DAQ. + + Returns: + - DataPaths: Instance of DataPaths. + """ + data_raw_dir_list = [] + if loader == "flash": + beamtime_dir = beamtime_dir.joinpath(f"{year}/data/{beamtime_id}/") + raw_path = beamtime_dir.joinpath("raw") + + for path in raw_path.glob("**/*"): + if path.is_dir(): + dir_name = path.name + if dir_name.startswith("express-") or dir_name.startswith( + "online-", + ): + data_raw_dir_list.append(path.joinpath(daq)) + elif dir_name == daq.upper(): + data_raw_dir_list.append(path) + data_raw_dir = data_raw_dir_list[0] + + if loader == "sxp": + beamtime_dir = beamtime_dir.joinpath(f"{year}/{beamtime_id}/") + data_raw_dir = beamtime_dir.joinpath("raw") + + if not data_raw_dir.is_dir(): + raise FileNotFoundError("Raw data directories not found.") + + parquet_path = "processed/parquet" + data_parquet_dir = beamtime_dir.joinpath(parquet_path) + data_parquet_dir.mkdir(parents=True, exist_ok=True) + + return cls(data_raw_dir=data_raw_dir, data_parquet_dir=data_parquet_dir) + + +class AuxiliaryChannel(BaseModel): + """ + Represents auxiliary channels in DLD. + """ + + name: str + slice: int + dtype: Optional[str] = None + + +class Channel(BaseModel): + """ + Represents a data channel. + """ + + name: str + format: DataFormat + group_name: Optional[str] = None + index_key: Optional[str] = None + dataset_key: Optional[str] = None + slice: Optional[int] = None + dtype: Optional[str] = None + dldAuxChannels: Optional[dict] = None + max_hits: Optional[int] = None + scale: Optional[float] = None + + @model_validator(mode="after") + def set_index_dataset_key(self): + if self.index_key and self.dataset_key: + return self + if self.group_name: + self.index_key = self.group_name + "index" + if self.name == "timeStamp": + self.dataset_key = self.group_name + "time" + else: + self.dataset_key = self.group_name + "value" + else: + raise ValueError( + "Channel:", + self.name, + "Either 'group_name' or 'index_key' AND 'dataset_key' must be provided.", + ) + return self + + # if name is dldAux, check format to be per_train. If per_pulse, correct to per_train + @model_validator(mode="after") + def dldAux_format(self): + if self.name == "dldAux": + if self.format != DataFormat.PER_TRAIN: + print("The correct format for dldAux is per_train, not per_pulse. 
Correcting.") + self.format = DataFormat.PER_TRAIN + return self + + # validate dldAuxChannels + @model_validator(mode="after") + def check_dldAuxChannels(self): + if self.name == "dldAux": + if self.dldAuxChannels is None: + raise ValueError(f"Channel 'dldAux' requires 'dldAuxChannels' to be defined.") + for name, data in self.dldAuxChannels.items(): + # if data is int, convert to dict + if isinstance(data, int): + self.dldAuxChannels[name] = AuxiliaryChannel(name=name, slice=data) + elif isinstance(data, dict): + self.dldAuxChannels[name] = AuxiliaryChannel(name=name, **data) + + return self + + +class DataFrameConfig(BaseModel): + """ + Represents configuration for DataFrame. + """ + + daq: str + ubid_offset: int + forward_fill_iterations: int = 2 + split_sector_id_from_dld_time: bool = False + sector_id_reserved_bits: Optional[int] = None + channels: Dict[str, Any] + units: Dict[str, str] = None + stream_name_prefixes: Dict[str, str] = None + stream_name_prefix: str = "" + stream_name_postfixes: Dict[str, str] = None + stream_name_postfix: str = "" + beamtime_dir: Dict[str, str] + sector_id_column: Optional[str] = None + tof_column: Optional[str] = "dldTimeSteps" + num_trains: Optional[int] = None + + @field_validator("channels", mode="before") + @classmethod + def populate_channels(cls, v): + return {name: Channel(name=name, **data) for name, data in v.items()} + + # validate that pulseId + @field_validator("channels", mode="after") + @classmethod + def check_channels(cls, v): + if "pulseId" not in v: + raise ValueError("Channel: pulseId must be provided.") + return v + + # valide split_sector_id_from_dld_time and sector_id_reserved_bits + @model_validator(mode="after") + def check_sector_id_reserved_bits(self): + if self.split_sector_id_from_dld_time: + if self.sector_id_reserved_bits is None: + raise ValueError( + "'split_sector_id_from_dld_time' is True", + "Please provide 'sector_id_reserved_bits'.", + ) + if self.sector_id_column is None: + print("No sector_id_column provided. Defaulting to dldSectorID.") + self.sector_id_column = "dldSectorID" + return self + + # compute stream_name_prefix and stream_name_postfix + @model_validator(mode="after") + def set_stream_name_prefix_and_postfix(self): + if self.stream_name_prefixes is not None: + # check if daq is in stream_name_prefixes + if self.daq not in self.stream_name_prefixes: + raise ValueError( + f"DAQ type '{self.daq}' not found in stream_name_prefixes.", + ) + self.stream_name_prefix = self.stream_name_prefixes[self.daq] + + if self.stream_name_postfixes is not None: + # check if daq is in stream_name_postfixes + if self.daq not in self.stream_name_postfixes: + raise ValueError( + f"DAQ type '{self.daq}' not found in stream_name_postfixes.", + ) + self.stream_name_postfix = self.stream_name_postfixes[self.daq] + + return self + + +class CoreConfig(BaseModel): + """ + Represents core configuration for Flash. + """ + + loader: str = None + beamline: str = None + paths: Optional[DataPaths] = None + beamtime_id: int = None + year: int = None + base_folder: Optional[str] = None + + +class LoaderConfig(BaseModel): + """ + Configuration for the flash loader. 
+ """ + + core: CoreConfig + dataframe: DataFrameConfig + metadata: Optional[Dict] = None + nexus: Optional[Dict] = None + + @model_validator(mode="after") + def check_paths(self): + if self.core.paths is None: + # check if beamtime_id and year are set + if self.core.beamtime_id is None or self.core.year is None: + raise ValueError("Either 'paths' or 'beamtime_id' and 'year' must be provided.") + + daq = self.dataframe.daq + beamtime_dir_path = Path(self.dataframe.beamtime_dir[self.core.beamline]) + self.core.paths = DataPaths.from_beamtime_dir( + self.core.loader, + beamtime_dir_path, + self.core.beamtime_id, + self.core.year, + daq, + ) + + return self diff --git a/sed/loader/fel/dataframe.py b/sed/loader/fel/dataframe.py new file mode 100644 index 00000000..e7c0fc9e --- /dev/null +++ b/sed/loader/fel/dataframe.py @@ -0,0 +1,196 @@ +""" +This module provides functionality for creating pandas DataFrames from HDF5 files with multiple +channels, found using get_channels method. + +The DataFrameCreator class requires a configuration dictionary with only the dataframe key and +an open h5 file. It validates if provided [index and dataset keys] or [group_name key] has +groups existing in the h5 file. +Three formats of channels are supported: [per electron], [per pulse], and [per train]. +These can be accessed using the df_electron, df_pulse, and df_train properties respectively. +The combined DataFrame can be accessed using the df property. +Typical usage example: + + df_creator = DataFrameCreator(config, h5_file) + dataframe = df_creator.df +""" +from __future__ import annotations + +from functools import reduce + +import h5py +import numpy as np +from pandas import concat +from pandas import DataFrame +from pandas import Index +from pandas import MultiIndex +from pandas import Series + +from sed.loader.fel.config_model import DataFrameConfig +from sed.loader.fel.utils import get_channels + + +class DataFrameCreator: + """ + Utility class for creating pandas DataFrames from HDF5 files with multiple channels. + """ + + def __init__(self, config: DataFrameConfig, h5_file: h5py.File) -> None: + """ + Initializes the DataFrameCreator class. + + Args: + config (DataFrameConfig): The dataframe section of the config model. + h5_file (h5py.File): The open h5 file. + """ + self.h5_file: h5py.File = h5_file + self.failed_files_error: list[str] = [] + self.multi_index = get_channels(index=True) + self._config = config + + def get_dataset_array( + self, + channel: str, + slice_: bool = False, + ) -> tuple[Index, h5py.Dataset]: + """ + Returns a numpy array for a given channel name. + + Args: + channel (str): The name of the channel. + slice_ (bool): If True, applies slicing on the dataset. + + Returns: + tuple[Index, h5py.Dataset]: A tuple containing the train ID Index and the numpy array + for the channel's data. 
+ """ + # Get the data from the necessary h5 file and channel + index_key = self._config.channels.get(channel).index_key + dataset_key = self._config.channels.get(channel).dataset_key + + key = Index(self.h5_file[index_key], name="trainId") # macrobunch + dataset = self.h5_file[dataset_key] + + if slice_: + slice_index = self._config.channels.get(channel).slice + if slice_index is not None: + dataset = np.take(dataset, slice_index, axis=1) + # If np_array is size zero, fill with NaNs + if dataset.shape[0] == 0: + # Fill the np_array with NaN values of the same shape as train_id + dataset = np.full_like(key, np.nan, dtype=np.double) + + return key, dataset + + @property + def df_electron(self) -> DataFrame: + """ + Returns a pandas DataFrame for a given channel name of type [per electron]. + + Returns: + DataFrame: The pandas DataFrame for the 'per_electron' channel's data. + """ + raise NotImplementedError("This method must be implemented in a child class.") + + @property + def df_pulse(self) -> DataFrame: + """ + Returns a pandas DataFrame for a given channel name of type [per pulse]. + + Returns: + DataFrame: The pandas DataFrame for the 'per_pulse' channel's data. + """ + series = [] + channels = get_channels(self._config.channels, "per_pulse") + if not channels: + return DataFrame() + + for channel in channels: + # get slice + key, dataset = self.get_dataset_array(channel, slice_=True) + index = MultiIndex.from_product( + (key, np.arange(0, dataset.shape[1]), [0]), + names=self.multi_index, + ) + series.append(Series(dataset[()].ravel(), index=index, name=channel)) + + return concat(series, axis=1) # much faster when concatenating similarly indexed data first + + @property + def df_train(self) -> DataFrame: + """ + Returns a pandas DataFrame for a given channel name of type [per train]. + + Returns: + DataFrame: The pandas DataFrame for the 'per_train' channel's data. + """ + series = [] + channels = get_channels(self._config.channels, "per_train") + if not channels: + return DataFrame() + + for channel in channels: + key, dataset = self.get_dataset_array(channel, slice_=True) + index = MultiIndex.from_product( + (key, [0], [0]), + names=self.multi_index, + ) + if channel == "dldAux": + aux_channels = self._config.channels["dldAux"].dldAuxChannels + for aux_ch_name in aux_channels: + aux_ch = aux_channels[aux_ch_name] + series.append( + Series(dataset[: key.size, aux_ch.slice], index, name=aux_ch.name), + ) + else: + series.append(Series(dataset, index, name=channel)) + + return concat(series, axis=1) + + def validate_channel_keys(self) -> None: + """ + Validates if the index and dataset keys for all channels in config exist in the h5 file. + + Raises: + KeyError: If the index or dataset keys do not exist in the file. + """ + for channel in self._config.channels: + index_key = self._config.channels.get(channel).index_key + dataset_key = self._config.channels.get(channel).dataset_key + if index_key not in self.h5_file: + raise KeyError(f"Index key '{index_key}' doesn't exist in the file.") + if dataset_key not in self.h5_file: + raise KeyError(f"Dataset key '{dataset_key}' doesn't exist in the file.") + + @property + def df(self) -> DataFrame: + """ + Joins the 'per_electron', 'per_pulse', and 'per_train' using join operation, + returning a single dataframe. + + Returns: + DataFrame: The combined pandas DataFrame. 
+ """ + + self.validate_channel_keys() + + dfs_to_join = (self.df_electron, self.df_pulse, self.df_train) + + def conditional_join(left_df, right_df): + """Performs conditional join of two dataframes. + Logic: if both dataframes are empty, return empty dataframe. + If one of the dataframes is empty, return the other dataframe. + Otherwise, perform outer join on multiindex of the dataframes.""" + return ( + DataFrame() + if left_df.empty and right_df.empty + else right_df + if left_df.empty + else left_df + if right_df.empty + else left_df.join(right_df, on=self.multi_index, how="outer") + ) + + # Perform conditional join for each combination of dataframes using reduce + joined_df = reduce(conditional_join, dfs_to_join) + + return joined_df.sort_index() diff --git a/sed/loader/fel/parquet.py b/sed/loader/fel/parquet.py new file mode 100644 index 00000000..9ad19e04 --- /dev/null +++ b/sed/loader/fel/parquet.py @@ -0,0 +1,121 @@ +""" +The ParquetHandler class allows for saving and reading Dask DataFrames to/from Parquet files. +It also provides methods for initializing paths, saving Parquet files and also reading them +into a Dask DataFrame. + +Typical usage example: + + parquet_handler = ParquetHandler(parquet_names='data', folder=Path('/path/to/folder')) + parquet_handler.save_parquet(df) # df is a uncomputed Dask DataFrame + data = parquet_handler.read_parquet() +""" +from __future__ import annotations + +from pathlib import Path + +import dask.dataframe as ddf + + +class ParquetHandler: + """A class for handling the creation and manipulation of Parquet files.""" + + def __init__( + self, + parquet_names: str | list[str] = None, + folder: Path = None, + subfolder: str = "", + prefix: str = "", + suffix: str = "", + extension: str = "parquet", + parquet_paths: Path = None, + ): + """ + A handler for saving and reading Dask DataFrames to/from Parquet files. + + Args: + parquet_names Union[str, List[str]]: The base name of the Parquet files. + folder (Path): The directory where the Parquet file will be stored. + subfolder (str): Optional subfolder within the main folder. + prefix (str): Optional prefix for the Parquet file name. + suffix (str): Optional suffix for the Parquet file name. + parquet_path (Path): Optional custom path for the Parquet file. + """ + + self.parquet_paths: list[Path] = None + + if isinstance(parquet_names, str): + parquet_names = [parquet_names] + + if not folder and not parquet_paths: + raise ValueError("Please provide folder or parquet_paths.") + if folder and not parquet_names: + raise ValueError("With folder, please provide parquet_names.") + + # If parquet_paths is provided, use it and ignore the other arguments + # Else, initialize the paths + if parquet_paths: + self.parquet_paths = ( + parquet_paths if isinstance(parquet_paths, list) else [parquet_paths] + ) + else: + self._initialize_paths(parquet_names, folder, subfolder, prefix, suffix, extension) + + def _initialize_paths( + self, + parquet_names: list[str], + folder: Path, + subfolder: str = None, + prefix: str = None, + suffix: str = None, + extension: str = None, + ) -> None: + """ + Create the directory for the Parquet file. 
+ """ + # Create the full path for the Parquet file + parquet_dir = folder.joinpath(subfolder) + parquet_dir.mkdir(parents=True, exist_ok=True) + + if extension: + extension = f".{extension}" # to be backwards compatible + self.parquet_paths = [ + parquet_dir.joinpath(Path(f"{prefix}{name}{suffix}{extension}")) + for name in parquet_names + ] + + def save_parquet( + self, + dfs: ddf.DataFrame | list[ddf.DataFrame], + drop_index: bool = False, + ) -> None: + """ + Save the DataFrame to a Parquet file. + + Args: + dfs (DataFrame | ddf.DataFrame): The pandas or Dask Dataframe to be saved. + drop_index (bool): If True, drops the index before saving. + """ + # Compute the Dask DataFrame, reset the index, and save to Parquet + dfs = dfs if isinstance(dfs, list) else [dfs] + for df, parquet_path in zip(dfs, self.parquet_paths): + df.compute().reset_index(drop=drop_index).to_parquet(parquet_path) + + def read_parquet(self) -> list[ddf.DataFrame]: + """ + Read a Dask DataFrame from the Parquet file. + + Returns: + ddf.DataFrame: The Dask DataFrame read from the Parquet file. + + Raises: + FileNotFoundError: If the Parquet file does not exist. + """ + dfs = [] + for parquet_path in self.parquet_paths: + if not parquet_path.exists(): + raise FileNotFoundError( + "The Parquet file does not exist. " + "If it is in another location, provide the correct path as parquet_path.", + ) + dfs.append(ddf.read_parquet(parquet_path)) + return dfs diff --git a/sed/loader/fel/utils.py b/sed/loader/fel/utils.py new file mode 100644 index 00000000..2825fd31 --- /dev/null +++ b/sed/loader/fel/utils.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +MULTI_INDEX = ["trainId", "pulseId", "electronId"] +DLD_AUX_ALIAS = "dldAux" +DLDAUX_CHANNELS = "dldAuxChannels" +FORMATS = ["per_electron", "per_pulse", "per_train"] + + +def get_channels( + channels_dict: dict = None, + formats: str | list[str] = None, + index: bool = False, + extend_aux: bool = False, + remove_index_from_format: bool = True, +) -> list[str]: + """ + Returns a list of channels associated with the specified format(s). + + Args: + channels_dict (dict): The channels dictionary. + formats (Union[str, List[str]]): The desired format(s) + ('per_pulse', 'per_electron', 'per_train', 'all'). + index (bool): If True, includes channels from the multiindex. + extend_aux (bool): If True, includes channels from the 'dldAuxChannels' dictionary, + else includes 'dldAux'. + + Returns: + List[str]: A list of channels with the specified format(s). + """ + # If 'formats' is a single string, convert it to a list for uniform processing. + if isinstance(formats, str): + formats = [formats] + + # If 'formats' is a string "all", gather all possible formats. + if formats == ["all"]: + channels = get_channels( + channels_dict, + FORMATS, + index, + extend_aux, + ) + return channels + + channels = [] + + # Include channels from multi_index if 'index' is True. + if index: + channels.extend(MULTI_INDEX) + + if formats: + # If 'formats' is a list, check if all elements are valid. + for format_ in formats: + if format_ not in FORMATS + ["all"]: + raise ValueError( + "Invalid format. Please choose from 'per_electron', 'per_pulse',\ + 'per_train', 'all'.", + ) + + # Get the available channels excluding the index + available_channels = list(channels_dict.keys()) + for ch in MULTI_INDEX: + if remove_index_from_format and ch in available_channels: + available_channels.remove(ch) + + for format_ in formats: + # Gather channels based on the specified format(s). 
+ channels.extend( + key + for key in available_channels + if channels_dict[key].format == format_ and key != DLD_AUX_ALIAS + ) + # Include 'dldAuxChannels' if the format is 'per_pulse' and extend_aux is True. + # Otherwise, include 'dldAux'. + if format_ == FORMATS[2] and DLD_AUX_ALIAS in available_channels: + if extend_aux: + channels.extend( + channels_dict.get(DLD_AUX_ALIAS).dldAuxChannels.keys(), + ) + else: + channels.extend([DLD_AUX_ALIAS]) + + return channels diff --git a/sed/loader/flash/dataframe.py b/sed/loader/flash/dataframe.py new file mode 100644 index 00000000..588a811d --- /dev/null +++ b/sed/loader/flash/dataframe.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import numpy as np +from pandas import concat +from pandas import DataFrame +from pandas import MultiIndex +from pandas import Series + +from sed.loader.fel.dataframe import DataFrameCreator +from sed.loader.fel.utils import get_channels + + +class FlashDataFrameCreator(DataFrameCreator): + def pulse_index(self, offset: int) -> tuple[MultiIndex, slice | np.ndarray]: + """ + Computes the index for the 'per_electron' data. + + Args: + offset (int): The offset value. + + Returns: + tuple[MultiIndex, np.ndarray]: A tuple containing the computed MultiIndex and + the indexer. + """ + # Get the pulseId and the index_train + index_train, dataset_pulse = self.get_dataset_array("pulseId", slice_=True) + # Repeat the index_train by the number of pulses + index_train_repeat = np.repeat(index_train, dataset_pulse.shape[1]) + # Explode the pulse dataset and subtract by the ubid_offset + pulse_ravel = dataset_pulse.ravel() - offset + # Create a MultiIndex with the index_train and the pulse + microbunches = MultiIndex.from_arrays((index_train_repeat, pulse_ravel)).dropna() + + # Only sort if necessary + indexer = slice(None) + if not microbunches.is_monotonic_increasing: + microbunches, indexer = microbunches.sort_values(return_indexer=True) + + # Count the number of electrons per microbunch and create an array of electrons + electron_counts = microbunches.value_counts(sort=False).values + electrons = np.concatenate([np.arange(count) for count in electron_counts]) + + # Final index constructed here + index = MultiIndex.from_arrays( + ( + microbunches.get_level_values(0), + microbunches.get_level_values(1).astype(int), + electrons, + ), + names=self.multi_index, + ) + return index, indexer + + @property + def df_electron(self) -> DataFrame: + """ + Returns a pandas DataFrame for a given channel name of type [per electron]. + + Returns: + DataFrame: The pandas DataFrame for the 'per_electron' channel's data. 
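+
+ Example (sketch; df_config is the DataFrameConfig section of the loader config
+ and h5_file is an open h5py.File):
+
+ dfc = FlashDataFrameCreator(df_config, h5_file)
+ electrons = dfc.df_electron  # indexed by (trainId, pulseId, electronId)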
+ """ + channels = get_channels(self._config.channels, "per_electron") + if not channels: + return DataFrame() + + offset = self._config.ubid_offset + # Index + index, indexer = self.pulse_index(offset) + + # Data logic + slice_index = [self._config.channels.get(channel).slice for channel in channels] + + # First checking if dataset keys are the same for all channels + dataset_keys = [self._config.channels.get(channel).dataset_key for channel in channels] + all_keys_same = all(key == dataset_keys[0] for key in dataset_keys) + + # If all dataset keys are the same, we can directly use the ndarray to create frame + if all_keys_same: + _, dataset = self.get_dataset_array(channels[0]) + data_dict = { + channel: dataset[:, slice_, :].ravel() + for channel, slice_ in zip(channels, slice_index) + } + dataframe = DataFrame(data_dict) + # Otherwise, we need to create a Series for each channel and concatenate them + else: + series = { + channel: Series(self.get_dataset_array(channel, slice_=True)[1].ravel()) + for channel in channels + } + dataframe = concat(series, axis=1) + + drop_vals = np.arange(-offset, 0) + + # Few things happen here: + # Drop all NaN values like while creating the multiindex + # if necessary, the data is sorted with [indexer] + # MultiIndex is set + # Finally, the offset values are dropped + return ( + dataframe.dropna()[indexer] + .set_index(index) + .drop(index=drop_vals, level="pulseId", errors="ignore") + ) diff --git a/sed/loader/flash/loader.py b/sed/loader/flash/loader.py index 6c0b8d1b..7e2b2f1b 100644 --- a/sed/loader/flash/loader.py +++ b/sed/loader/flash/loader.py @@ -2,35 +2,26 @@ This module implements the flash data loader. This loader currently supports hextof, wespe and instruments with similar structure. The raw hdf5 data is combined and saved into buffer files and loaded as a dask dataframe. -The dataframe is a amalgamation of all h5 files for a combination of runs, where the NaNs are -automatically forward filled across different files. +The dataframe is an amalgamation of all h5 files for a combination of runs, where the NaNs are +automatically forward-filled across different files. This can then be saved as a parquet for out-of-sed processing and reread back to access other -sed funtionality. +sed functionality. 
""" +from __future__ import annotations + import time -from functools import reduce from pathlib import Path -from typing import List from typing import Sequence -from typing import Tuple -from typing import Union import dask.dataframe as dd -import h5py -import numpy as np -import pyarrow.parquet as pq -from joblib import delayed -from joblib import Parallel from natsort import natsorted -from pandas import DataFrame -from pandas import MultiIndex -from pandas import Series -from sed.core import dfops from sed.loader.base.loader import BaseLoader +from sed.loader.fel import BufferHandler +from sed.loader.fel import ParquetHandler +from sed.loader.fel.config_model import LoaderConfig +from sed.loader.flash.dataframe import FlashDataFrameCreator from sed.loader.flash.metadata import MetadataRetriever -from sed.loader.utils import parse_h5_keys -from sed.loader.utils import split_dld_time_from_sector_id class FlashLoader(BaseLoader): @@ -45,80 +36,24 @@ class FlashLoader(BaseLoader): supported_file_types = ["h5"] def __init__(self, config: dict) -> None: - super().__init__(config=config) - self.multi_index = ["trainId", "pulseId", "electronId"] - self.index_per_electron: MultiIndex = None - self.index_per_pulse: MultiIndex = None - self.failed_files_error: List[str] = [] - - def initialize_paths(self) -> Tuple[List[Path], Path]: """ - Initializes the paths based on the configuration. - - Returns: - Tuple[List[Path], Path]: A tuple containing a list of raw data directories - paths and the parquet data directory path. + Initializes the FlashLoader. - Raises: - ValueError: If required values are missing from the configuration. - FileNotFoundError: If the raw data directories are not found. + Args: + config (dict): The configuration dictionary or model. """ - # Parses to locate the raw beamtime directory from config file - if "paths" in self._config["core"]: - data_raw_dir = [ - Path(self._config["core"]["paths"].get("data_raw_dir", "")), - ] - data_parquet_dir = Path( - self._config["core"]["paths"].get("data_parquet_dir", ""), - ) - - else: - try: - beamtime_id = self._config["core"]["beamtime_id"] - year = self._config["core"]["year"] - daq = self._config["dataframe"]["daq"] - except KeyError as exc: - raise ValueError( - "The beamtime_id, year and daq are required.", - ) from exc - - beamtime_dir = Path( - self._config["dataframe"]["beamtime_dir"][self._config["core"]["beamline"]], - ) - beamtime_dir = beamtime_dir.joinpath(f"{year}/data/{beamtime_id}/") - - # Use pathlib walk to reach the raw data directory - data_raw_dir = [] - raw_path = beamtime_dir.joinpath("raw") - - for path in raw_path.glob("**/*"): - if path.is_dir(): - dir_name = path.name - if dir_name.startswith("express-") or dir_name.startswith( - "online-", - ): - data_raw_dir.append(path.joinpath(daq)) - elif dir_name == daq.upper(): - data_raw_dir.append(path) - - if not data_raw_dir: - raise FileNotFoundError("Raw data directories not found.") - - parquet_path = "processed/parquet" - data_parquet_dir = beamtime_dir.joinpath(parquet_path) - - data_parquet_dir.mkdir(parents=True, exist_ok=True) - - return data_raw_dir, data_parquet_dir + super().__init__(config=config) + self.config = LoaderConfig(**self._config) def get_files_from_run_id( self, run_id: str, - folders: Union[str, Sequence[str]] = None, + folders: str | Sequence[str] = None, extension: str = "h5", **kwds, - ) -> List[str]: - """Returns a list of filenames for a given run located in the specified directory + ) -> list[str]: + """ + Returns a list of filenames 
for a given run located in the specified directory for the specified data acquisition (daq). Args: @@ -126,8 +61,6 @@ def get_files_from_run_id( folders (Union[str, Sequence[str]], optional): The directory(ies) where the raw data is located. Defaults to config["core"]["base_folder"]. extension (str, optional): The file extension. Defaults to "h5". - kwds: Keyword arguments: - - daq (str): The data acquisition identifier. Returns: List[str]: A list of path strings representing the collected file names. @@ -136,20 +69,18 @@ def get_files_from_run_id( FileNotFoundError: If no files are found for the given run in the directory. """ # Define the stream name prefixes based on the data acquisition identifier - stream_name_prefixes = self._config["dataframe"]["stream_name_prefixes"] + stream_name_prefix = self.config.dataframe.stream_name_prefix if folders is None: - folders = self._config["core"]["base_folder"] + folders = self.config.core.base_folder if isinstance(folders, str): folders = [folders] - daq = kwds.pop("daq", self._config.get("dataframe", {}).get("daq")) - # Generate the file patterns to search for in the directory - file_pattern = f"{stream_name_prefixes[daq]}_run{run_id}_*." + extension + file_pattern = f"{stream_name_prefix}_run{run_id}_*." + extension - files: List[Path] = [] + files: list[Path] = [] # Use pathlib to search for matching files in each directory for folder in folders: files.extend( @@ -168,682 +99,18 @@ def get_files_from_run_id( # Return the list of found files return [str(file.resolve()) for file in files] - @property - def available_channels(self) -> List: - """Returns the channel names that are available for use, - excluding pulseId, defined by the json file""" - available_channels = list(self._config["dataframe"]["channels"].keys()) - available_channels.remove("pulseId") - return available_channels - - def get_channels(self, formats: Union[str, List[str]] = "", index: bool = False) -> List[str]: - """ - Returns a list of channels associated with the specified format(s). - - Args: - formats (Union[str, List[str]]): The desired format(s) - ('per_pulse', 'per_electron', 'per_train', 'all'). - index (bool): If True, includes channels from the multi_index. - - Returns: - List[str]: A list of channels with the specified format(s). - """ - # If 'formats' is a single string, convert it to a list for uniform processing. - if isinstance(formats, str): - formats = [formats] - - # If 'formats' is a string "all", gather all possible formats. - if formats == ["all"]: - channels = self.get_channels(["per_pulse", "per_train", "per_electron"], index) - return channels - - channels = [] - for format_ in formats: - # Gather channels based on the specified format(s). - channels.extend( - key - for key in self.available_channels - if self._config["dataframe"]["channels"][key]["format"] == format_ - and key != "dldAux" - ) - # Include 'dldAuxChannels' if the format is 'per_pulse'. - if format_ == "per_pulse": - channels.extend( - self._config["dataframe"]["channels"]["dldAux"]["dldAuxChannels"].keys(), - ) - - # Include channels from multi_index if 'index' is True. - if index: - channels.extend(self.multi_index) - - return channels - - def reset_multi_index(self) -> None: - """Resets the index per pulse and electron""" - self.index_per_electron = None - self.index_per_pulse = None - - def create_multi_index_per_electron(self, h5_file: h5py.File) -> None: - """ - Creates an index per electron using pulseId for usage with the electron - resolved pandas DataFrame. 
- - Args: - h5_file (h5py.File): The HDF5 file object. - - Notes: - - This method relies on the 'pulseId' channel to determine - the macrobunch IDs. - - It creates a MultiIndex with trainId, pulseId, and electronId - as the index levels. - """ - - # Macrobunch IDs obtained from the pulseId channel - [train_id, np_array] = self.create_numpy_array_per_channel( - h5_file, - "pulseId", - ) - - # Create a series with the macrobunches as index and - # microbunches as values - macrobunches = ( - Series( - (np_array[i] for i in train_id.index), - name="pulseId", - index=train_id, - ) - - self._config["dataframe"]["ubid_offset"] - ) - - # Explode dataframe to get all microbunch vales per macrobunch, - # remove NaN values and convert to type int - microbunches = macrobunches.explode().dropna().astype(int) - - # Create temporary index values - index_temp = MultiIndex.from_arrays( - (microbunches.index, microbunches.values), - names=["trainId", "pulseId"], - ) - - # Calculate the electron counts per pulseId unique preserves the order of appearance - electron_counts = index_temp.value_counts()[index_temp.unique()].values - - # Series object for indexing with electrons - electrons = ( - Series( - [np.arange(electron_counts[i]) for i in range(electron_counts.size)], - ) - .explode() - .astype(int) - ) - - # Create a pandas MultiIndex using the exploded datasets - self.index_per_electron = MultiIndex.from_arrays( - (microbunches.index, microbunches.values, electrons), - names=self.multi_index, - ) - - def create_multi_index_per_pulse( - self, - train_id: Series, - np_array: np.ndarray, - ) -> None: - """ - Creates an index per pulse using a pulse resolved channel's macrobunch ID, for usage with - the pulse resolved pandas DataFrame. - - Args: - train_id (Series): The train ID Series. - np_array (np.ndarray): The numpy array containing the pulse resolved data. - - Notes: - - This method creates a MultiIndex with trainId and pulseId as the index levels. - """ - - # Create a pandas MultiIndex, useful for comparing electron and - # pulse resolved dataframes - self.index_per_pulse = MultiIndex.from_product( - (train_id, np.arange(0, np_array.shape[1])), - names=["trainId", "pulseId"], - ) - - def create_numpy_array_per_channel( - self, - h5_file: h5py.File, - channel: str, - ) -> Tuple[Series, np.ndarray]: - """ - Returns a numpy array for a given channel name for a given file. - - Args: - h5_file (h5py.File): The h5py file object. - channel (str): The name of the channel. - - Returns: - Tuple[Series, np.ndarray]: A tuple containing the train ID Series and the numpy array - for the channel's data. - - """ - # Get the data from the necessary h5 file and channel - group = h5_file[self._config["dataframe"]["channels"][channel]["group_name"]] - - channel_dict = self._config["dataframe"]["channels"][channel] # channel parameters - - train_id = Series(group["index"], name="trainId") # macrobunch - - # unpacks the timeStamp or value - if channel == "timeStamp": - np_array = group["time"][()] - else: - np_array = group["value"][()] - - # Use predefined axis and slice from the json file - # to choose correct dimension for necessary channel - if "slice" in channel_dict: - np_array = np.take( - np_array, - channel_dict["slice"], - axis=1, - ) - return train_id, np_array - - def create_dataframe_per_electron( - self, - np_array: np.ndarray, - train_id: Series, - channel: str, - ) -> DataFrame: - """ - Returns a pandas DataFrame for a given channel name of type [per electron]. 
- - Args: - np_array (np.ndarray): The numpy array containing the channel data. - train_id (Series): The train ID Series. - channel (str): The name of the channel. - - Returns: - DataFrame: The pandas DataFrame for the channel's data. - - Notes: - The microbunch resolved data is exploded and converted to a DataFrame. The MultiIndex - is set, and the NaN values are dropped, alongside the pulseId = 0 (meaningless). - - """ - return ( - Series((np_array[i] for i in train_id.index), name=channel) - .explode() - .dropna() - .to_frame() - .set_index(self.index_per_electron) - .drop( - index=np.arange(-self._config["dataframe"]["ubid_offset"], 0), - level=1, - errors="ignore", - ) - ) - - def create_dataframe_per_pulse( - self, - np_array: np.ndarray, - train_id: Series, - channel: str, - channel_dict: dict, - ) -> DataFrame: - """ - Returns a pandas DataFrame for a given channel name of type [per pulse]. - - Args: - np_array (np.ndarray): The numpy array containing the channel data. - train_id (Series): The train ID Series. - channel (str): The name of the channel. - channel_dict (dict): The dictionary containing channel parameters. - - Returns: - DataFrame: The pandas DataFrame for the channel's data. - - Notes: - - For auxillary channels, the macrobunch resolved data is repeated 499 times to be - compared to electron resolved data for each auxillary channel. The data is then - converted to a multicolumn DataFrame. - - For all other pulse resolved channels, the macrobunch resolved data is exploded - to a DataFrame and the MultiIndex is set. - - """ - - # Special case for auxillary channels - if channel == "dldAux": - # Checks the channel dictionary for correct slices and creates a multicolumn DataFrame - data_frames = ( - Series( - (np_array[i, value] for i in train_id.index), - name=key, - index=train_id, - ).to_frame() - for key, value in channel_dict["dldAuxChannels"].items() - ) - - # Multiindex set and combined dataframe returned - data = reduce(DataFrame.combine_first, data_frames) - - # For all other pulse resolved channels - else: - # Macrobunch resolved data is exploded to a DataFrame and the MultiIndex is set - - # Creates the index_per_pulse for the given channel - self.create_multi_index_per_pulse(train_id, np_array) - data = ( - Series((np_array[i] for i in train_id.index), name=channel) - .explode() - .to_frame() - .set_index(self.index_per_pulse) - ) - - return data - - def create_dataframe_per_train( - self, - np_array: np.ndarray, - train_id: Series, - channel: str, - ) -> DataFrame: - """ - Returns a pandas DataFrame for a given channel name of type [per train]. - - Args: - np_array (np.ndarray): The numpy array containing the channel data. - train_id (Series): The train ID Series. - channel (str): The name of the channel. - - Returns: - DataFrame: The pandas DataFrame for the channel's data. - """ - return ( - Series((np_array[i] for i in train_id.index), name=channel) - .to_frame() - .set_index(train_id) - ) - - def create_dataframe_per_channel( - self, - h5_file: h5py.File, - channel: str, - ) -> Union[Series, DataFrame]: - """ - Returns a pandas DataFrame for a given channel name from a given file. - - This method takes an h5py.File object `h5_file` and a channel name `channel`, and returns - a pandas DataFrame containing the data for that channel from the file. The format of the - DataFrame depends on the channel's format specified in the configuration. - - Args: - h5_file (h5py.File): The h5py.File object representing the HDF5 file. 
- channel (str): The name of the channel. - - Returns: - Union[Series, DataFrame]: A pandas Series or DataFrame representing the channel's data. - - Raises: - ValueError: If the channel has an undefined format. - - """ - [train_id, np_array] = self.create_numpy_array_per_channel( - h5_file, - channel, - ) # numpy Array created - channel_dict = self._config["dataframe"]["channels"][channel] # channel parameters - - # If np_array is size zero, fill with NaNs - if np_array.size == 0: - # Fill the np_array with NaN values of the same shape as train_id - np_array = np.full_like(train_id, np.nan, dtype=np.double) - # Create a Series using np_array, with train_id as the index - data = Series( - (np_array[i] for i in train_id.index), - name=channel, - index=train_id, - ) - - # Electron resolved data is treated here - if channel_dict["format"] == "per_electron": - # If index_per_electron is None, create it for the given file - if self.index_per_electron is None: - self.create_multi_index_per_electron(h5_file) - - # Create a DataFrame for electron-resolved data - data = self.create_dataframe_per_electron( - np_array, - train_id, - channel, - ) - - # Pulse resolved data is treated here - elif channel_dict["format"] == "per_pulse": - # Create a DataFrame for pulse-resolved data - data = self.create_dataframe_per_pulse( - np_array, - train_id, - channel, - channel_dict, - ) - - # Train resolved data is treated here - elif channel_dict["format"] == "per_train": - # Create a DataFrame for train-resolved data - data = self.create_dataframe_per_train(np_array, train_id, channel) - - else: - raise ValueError( - channel - + "has an undefined format. Available formats are \ - per_pulse, per_electron and per_train", - ) - - return data - - def concatenate_channels( - self, - h5_file: h5py.File, - ) -> DataFrame: - """ - Concatenates the channels from the provided h5py.File into a pandas DataFrame. - - This method takes an h5py.File object `h5_file` and concatenates the channels present in - the file into a single pandas DataFrame. The concatenation is performed based on the - available channels specified in the configuration. - - Args: - h5_file (h5py.File): The h5py.File object representing the HDF5 file. - - Returns: - DataFrame: A concatenated pandas DataFrame containing the channels. - - Raises: - ValueError: If the group_name for any channel does not exist in the file. - - """ - all_keys = parse_h5_keys(h5_file) # Parses all channels present - - # Check for if the provided group_name actually exists in the file - for channel in self._config["dataframe"]["channels"]: - if channel == "timeStamp": - group_name = self._config["dataframe"]["channels"][channel]["group_name"] + "time" - else: - group_name = self._config["dataframe"]["channels"][channel]["group_name"] + "value" - - if group_name not in all_keys: - raise ValueError( - f"The group_name for channel {channel} does not exist.", - ) - - # Create a generator expression to generate data frames for each channel - data_frames = ( - self.create_dataframe_per_channel(h5_file, each) for each in self.available_channels - ) - - # Use the reduce function to join the data frames into a single DataFrame - return reduce( - lambda left, right: left.join(right, how="outer"), - data_frames, - ) - - def create_dataframe_per_file( - self, - file_path: Path, - ) -> DataFrame: - """ - Create pandas DataFrames for the given file. - - This method loads an HDF5 file specified by `file_path` and constructs a pandas DataFrame - from the datasets within the file. 
The order of datasets in the DataFrames is the opposite - of the order specified by channel names. - - Args: - file_path (Path): Path to the input HDF5 file. - - Returns: - DataFrame: pandas DataFrame - - """ - # Loads h5 file and creates a dataframe - with h5py.File(file_path, "r") as h5_file: - self.reset_multi_index() # Reset MultiIndexes for next file - df = self.concatenate_channels(h5_file) - df = df.dropna(subset=self._config["dataframe"].get("tof_column", "dldTimeSteps")) - # correct the 3 bit shift which encodes the detector ID in the 8s time - if self._config["dataframe"].get("split_sector_id_from_dld_time", False): - df = split_dld_time_from_sector_id(df, config=self._config) - return df - - def create_buffer_file(self, h5_path: Path, parquet_path: Path) -> Union[bool, Exception]: - """ - Converts an HDF5 file to Parquet format to create a buffer file. - - This method uses `create_dataframe_per_file` method to create dataframes from individual - files within an HDF5 file. The resulting dataframe is then saved to a Parquet file. - - Args: - h5_path (Path): Path to the input HDF5 file. - parquet_path (Path): Path to the output Parquet file. - - Raises: - ValueError: If an error occurs during the conversion process. - - """ - try: - ( - self.create_dataframe_per_file(h5_path) - .reset_index(level=self.multi_index) - .to_parquet(parquet_path, index=False) - ) - except Exception as exc: # pylint: disable=broad-except - self.failed_files_error.append(f"{parquet_path}: {type(exc)} {exc}") - return exc - return None - - def buffer_file_handler( - self, - data_parquet_dir: Path, - detector: str, - force_recreate: bool, - ) -> Tuple[List[Path], List, List]: - """ - Handles the conversion of buffer files (h5 to parquet) and returns the filenames. - - Args: - data_parquet_dir (Path): Directory where the parquet files will be stored. - detector (str): Detector name. - force_recreate (bool): Forces recreation of buffer files - - Returns: - Tuple[List[Path], List, List]: Three lists, one for - parquet file paths, one for metadata and one for schema. - - Raises: - FileNotFoundError: If the conversion fails for any files or no data is available. - """ - - # Create the directory for buffer parquet files - buffer_file_dir = data_parquet_dir.joinpath("buffer") - buffer_file_dir.mkdir(parents=True, exist_ok=True) - - # Create two separate lists for h5 and parquet file paths - h5_filenames = [Path(file) for file in self.files] - parquet_filenames = [ - buffer_file_dir.joinpath(Path(file).stem + detector) for file in self.files - ] - existing_parquet_filenames = [file for file in parquet_filenames if file.exists()] - - # Raise a value error if no data is available after the conversion - if len(h5_filenames) == 0: - raise ValueError("No data available. 
Probably failed reading all h5 files") - - if not force_recreate: - # Check if the available channels match the schema of the existing parquet files - parquet_schemas = [pq.read_schema(file) for file in existing_parquet_filenames] - config_schema = set(self.get_channels(formats="all", index=True)) - if self._config["dataframe"].get("split_sector_id_from_dld_time", False): - config_schema.add(self._config["dataframe"].get("sector_id_column", False)) - - for i, schema in enumerate(parquet_schemas): - schema_set = set(schema.names) - if schema_set != config_schema: - missing_in_parquet = config_schema - schema_set - missing_in_config = schema_set - config_schema - - missing_in_parquet_str = ( - f"Missing in parquet: {missing_in_parquet}" if missing_in_parquet else "" - ) - missing_in_config_str = ( - f"Missing in config: {missing_in_config}" if missing_in_config else "" - ) - - raise ValueError( - "The available channels do not match the schema of file", - f"{existing_parquet_filenames[i]}", - f"{missing_in_parquet_str}", - f"{missing_in_config_str}", - "Please check the configuration file or set force_recreate to True.", - ) - - # Choose files to read - files_to_read = [ - (h5_path, parquet_path) - for h5_path, parquet_path in zip(h5_filenames, parquet_filenames) - if force_recreate or not parquet_path.exists() - ] - - print(f"Reading files: {len(files_to_read)} new files of {len(h5_filenames)} total.") - - # Initialize the indices for create_buffer_file conversion - self.reset_multi_index() - - # Convert the remaining h5 files to parquet in parallel if there are any - if len(files_to_read) > 0: - error = Parallel(n_jobs=len(files_to_read), verbose=10)( - delayed(self.create_buffer_file)(h5_path, parquet_path) - for h5_path, parquet_path in files_to_read - ) - if any(error): - raise RuntimeError(f"Conversion failed for some files. {error}") - - # Raise an error if the conversion failed for any files - # TODO: merge this and the previous error trackings - if self.failed_files_error: - raise FileNotFoundError( - "Conversion failed for the following files:\n" + "\n".join(self.failed_files_error), - ) - - print("All files converted successfully!") - - # read all parquet metadata and schema - metadata = [pq.read_metadata(file) for file in parquet_filenames] - schema = [pq.read_schema(file) for file in parquet_filenames] - - return parquet_filenames, metadata, schema - - def parquet_handler( - self, - data_parquet_dir: Path, - detector: str = "", - parquet_path: Path = None, - converted: bool = False, - load_parquet: bool = False, - save_parquet: bool = False, - force_recreate: bool = False, - ) -> Tuple[dd.DataFrame, dd.DataFrame]: - """ - Handles loading and saving of parquet files based on the provided parameters. - - Args: - data_parquet_dir (Path): Directory where the parquet files are located. - detector (str, optional): Adds a identifier for parquets to distinguish multidetector - systems. - parquet_path (str, optional): Path to the combined parquet file. - converted (bool, optional): True if data is augmented by adding additional columns - externally and saved into converted folder. - load_parquet (bool, optional): Loads the entire parquet into the dd dataframe. - save_parquet (bool, optional): Saves the entire dataframe into a parquet. - force_recreate (bool, optional): Forces recreation of buffer file. - Returns: - tuple: A tuple containing two dataframes: - - dataframe_electron: Dataframe containing the loaded/augmented electron data. 
- - dataframe_pulse: Dataframe containing the loaded/augmented timed data. - - Raises: - FileNotFoundError: If the requested parquet file is not found. - - """ - - # Construct the parquet path if not provided - if parquet_path is None: - parquet_name = "_".join(str(run) for run in self.runs) - parquet_dir = data_parquet_dir.joinpath("converted") if converted else data_parquet_dir - - parquet_path = parquet_dir.joinpath( - "run_" + parquet_name + detector, - ).with_suffix(".parquet") - - # Check if load_parquet is flagged and then load the file if it exists - if load_parquet: - try: - dataframe = dd.read_parquet(parquet_path) - except Exception as exc: - raise FileNotFoundError( - "The final parquet for this run(s) does not exist yet. " - "If it is in another location, please provide the path as parquet_path.", - ) from exc - - else: - # Obtain the parquet filenames, metadata and schema from the method - # which handles buffer file creation/reading - filenames, metadata, _ = self.buffer_file_handler( - data_parquet_dir, - detector, - force_recreate, - ) - - # Read all parquet files into one dataframe using dask - dataframe = dd.read_parquet(filenames, calculate_divisions=True) - - # Channels to fill NaN values - channels: List[str] = self.get_channels(["per_pulse", "per_train"]) - - overlap = min(file.num_rows for file in metadata) - - print("Filling nan values...") - dataframe = dfops.forward_fill_lazy( - df=dataframe, - columns=channels, - before=overlap, - iterations=self._config["dataframe"].get("forward_fill_iterations", 2), - ) - # Remove the NaNs from per_electron channels - dataframe_electron = dataframe.dropna( - subset=self.get_channels(["per_electron"]), - ) - dataframe_pulse = dataframe[ - self.multi_index + self.get_channels(["per_pulse", "per_train"]) - ] - dataframe_pulse = dataframe_pulse[ - (dataframe_pulse["electronId"] == 0) | (np.isnan(dataframe_pulse["electronId"])) - ] - - # Save the dataframe as parquet if requested - if save_parquet: - dataframe_electron.compute().reset_index(drop=True).to_parquet(parquet_path) - print("Combined parquet file saved.") - - return dataframe_electron, dataframe_pulse - def parse_metadata(self) -> dict: """Uses the MetadataRetriever class to fetch metadata from scicat for each run. Returns: dict: Metadata dictionary """ - metadata_retriever = MetadataRetriever(self._config["metadata"]) + # check if beamtime_id is set + if self.config.core.beamtime_id is None: + raise ValueError("Beamtime ID is required to fetch metadata.") + metadata_retriever = MetadataRetriever(self.config.metadata) metadata = metadata_retriever.get_metadata( - beamtime_id=self._config["core"]["beamtime_id"], + beamtime_id=self.config.core.beamtime_id, runs=self.runs, metadata=self.metadata, ) @@ -862,14 +129,20 @@ def get_elapsed_time(self, fids=None, **kwds): def read_dataframe( self, - files: Union[str, Sequence[str]] = None, - folders: Union[str, Sequence[str]] = None, - runs: Union[str, Sequence[str]] = None, + files: str | Sequence[str] = None, + folders: str | Sequence[str] = None, + runs: str | Sequence[str] = None, ftype: str = "h5", metadata: dict = None, collect_metadata: bool = False, + converted: bool = False, + load_parquet: bool = False, + save_parquet: bool = False, + detector: str = "", + force_recreate: bool = False, + debug: bool = False, **kwds, - ) -> Tuple[dd.DataFrame, dd.DataFrame, dict]: + ) -> tuple[dd.DataFrame, dd.DataFrame, dict]: """ Read express data from the DAQ, generating a parquet in between. 
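For orientation, a minimal usage sketch of the reworked API follows (the import path of the loader and the config file are assumptions based on the test data in this change; the keyword arguments mirror the new read_dataframe signature above):

# Illustrative sketch only: the FlashLoader import path and the config file are assumed;
# the keyword arguments follow the new read_dataframe signature introduced in this change.
from sed.core.config import parse_config
from sed.loader.flash.loader import FlashLoader  # assumed module path

config = parse_config("tests/data/loader/flash/config.yaml")  # example test config
loader = FlashLoader(config=config)

# Buffer parquet files are created (or reused) per raw h5 file via BufferHandler,
# then combined into an electron-resolved and a timed (pulse-resolved) dataframe.
df, df_timed, metadata = loader.read_dataframe(
    runs=[43878],          # resolved to raw files through get_files_from_run_id
    force_recreate=False,  # reuse existing buffer files if their schema matches
    save_parquet=True,     # additionally store the combined dataframes via ParquetHandler
    debug=False,           # set True to serialize buffer creation for easier debugging
)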
@@ -879,14 +152,15 @@ def read_dataframe( Path has priority such that if it's specified, the specified files will be ignored. Defaults to None. runs (Union[str, Sequence[str]], optional): Run identifier(s). Corresponding files will - be located in the location provided by ``folders``. Takes precendence over + be located in the location provided by ``folders``. Takes precedence over ``files`` and ``folders``. Defaults to None. ftype (str, optional): The file extension type. Defaults to "h5". metadata (dict, optional): Additional metadata. Defaults to None. collect_metadata (bool, optional): Whether to collect metadata. Defaults to False. Returns: - Tuple[dd.DataFrame, dict]: A tuple containing the concatenated DataFrame and metadata. + Tuple[dd.DataFrame, dd.DataFrame, dict]: A tuple containing the concatenated DataFrame + and metadata. Raises: ValueError: If neither 'runs' nor 'files'/'data_raw_dir' is provided. @@ -894,7 +168,9 @@ def read_dataframe( """ t0 = time.time() - data_raw_dir, data_parquet_dir = self.initialize_paths() + paths = self.config.core.paths + data_raw_dir = paths.data_raw_dir + data_parquet_dir = paths.data_parquet_dir # Prepare a list of names for the runs to read and parquets to write if runs is not None: @@ -904,9 +180,9 @@ def read_dataframe( for run in runs: run_files = self.get_files_from_run_id( run_id=run, - folders=[str(folder.resolve()) for folder in data_raw_dir], + folders=[str(folder.resolve()) for folder in [data_raw_dir]], extension=ftype, - daq=self._config["dataframe"]["daq"], + daq=self.config.dataframe.daq, ) files.extend(run_files) self.runs = list(runs) @@ -922,7 +198,44 @@ def read_dataframe( metadata=metadata, ) - df, df_timed = self.parquet_handler(data_parquet_dir, **kwds) + # if parquet_dir is None, use data_parquet_dir + filename = "_".join(str(run) for run in self.runs) + converted_str = "converted" if converted else "" + # Create parquet paths for saving and loading the parquet files of df and timed_df + ph = ParquetHandler( + [filename, filename + "_timed"], + data_parquet_dir, + converted_str, + "run_", + detector, + ) + + # Check if load_parquet is flagged and then load the file if it exists + if load_parquet: + df_list = ph.read_parquet() + df = df_list[0] + df_timed = df_list[1] + + # Default behavior is to create the buffer files and load them + else: + # Obtain the parquet filenames, metadata, and schema from the method + # which handles buffer file creation/reading + h5_paths = [Path(file) for file in self.files] + buffer = BufferHandler( + FlashDataFrameCreator, + self.config.dataframe, + h5_paths, + data_parquet_dir, + force_recreate, + suffix=detector, + debug=debug, + ) + df = buffer.dataframe_electron + df_timed = buffer.dataframe_pulse + + # Save the dataframe as parquet if requested + if save_parquet: + ph.save_parquet([df, df_timed], drop_index=True) metadata = self.parse_metadata() if collect_metadata else {} print(f"loading complete in {time.time() - t0: .2f} s") diff --git a/sed/loader/flash/metadata.py b/sed/loader/flash/metadata.py index de256e4d..c62fd644 100644 --- a/sed/loader/flash/metadata.py +++ b/sed/loader/flash/metadata.py @@ -33,7 +33,7 @@ def __init__(self, metadata_config: Dict) -> None: def get_metadata( self, - beamtime_id: str, + beamtime_id: int, runs: list, metadata: Optional[Dict] = None, ) -> Dict: @@ -41,7 +41,7 @@ def get_metadata( Retrieves metadata for a given beamtime ID and list of runs. Args: - beamtime_id (str): The ID of the beamtime. + beamtime_id (int): The ID of the beamtime. 
runs (list): A list of run IDs. metadata (Dict, optional): The existing metadata dictionary. Defaults to None. diff --git a/sed/loader/sxp/dataframe.py b/sed/loader/sxp/dataframe.py new file mode 100644 index 00000000..677574f5 --- /dev/null +++ b/sed/loader/sxp/dataframe.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import numpy as np +from pandas import concat +from pandas import DataFrame +from pandas import Index +from pandas import MultiIndex +from pandas import Series + +from sed.loader.fel.dataframe import DataFrameCreator +from sed.loader.fel.utils import get_channels + + +class SXPDataFrameCreator(DataFrameCreator): + def get_dataset_array( + self, + channel: str, + slice_: bool = False, + ) -> tuple[Index, np.ndarray]: + """ + Returns a numpy array for a given channel name for a given file. + + Args: + h5_file (h5py.File): The h5py file object. + channel (str): The name of the channel. + + Returns: + Tuple[Index, np.ndarray]: A tuple containing the train ID Series and the numpy array + for the channel's data. + + """ + # Get the data from the necessary h5 file and channel + channel_dict = self._config.channels.get(channel) + index_key = channel_dict.index_key + dataset_key = channel_dict.dataset_key + + key = Index(self.h5_file[index_key], name="trainId") + + # unpacks the data into np.ndarray + np_array = self.h5_file[dataset_key][()] + if len(np_array.shape) == 2 and channel_dict.max_hits: + np_array = np_array[:, : channel_dict.max_hits] + + if channel_dict.scale: + np_array = np_array / float(channel_dict.scale) + + # If np_array is size zero, fill with NaNs + if len(np_array.shape) == 0: + # Fill the np_array with NaN values of the same shape as train_id + np_array = np.full_like(key, np.nan, dtype=np.double) + + return key, np_array + + def pulse_index(self) -> MultiIndex: + """ + Computes the index for the 'per_electron' data. + + Args: + offset (int): The offset value. 
+
+        Returns:
+            tuple: The list of macrobunch slice indices and the computed MultiIndex.
+        """
+        train_id, mab_array = self.get_dataset_array("trainId")
+        train_id, mib_array = self.get_dataset_array("pulseId")
+
+        macrobunch_index = []
+        microbunch_ids = []
+        macrobunch_indices = []
+
+        for i, _ in enumerate(train_id):
+            num_trains = self._config.num_trains
+            if num_trains:
+                try:
+                    num_valid_hits = np.where(np.diff(mib_array[i].astype(np.int32)) < 0)[0][
+                        num_trains - 1
+                    ]
+                    mab_array[i, num_valid_hits:] = 0
+                    mib_array[i, num_valid_hits:] = 0
+                except IndexError:
+                    pass
+
+            train_ends = np.where(np.diff(mib_array[i].astype(np.int32)) < -1)[0]
+            indices = []
+            index = 0
+            for train, train_end in enumerate(train_ends):
+                macrobunch_index.append(train_id[i] + np.uint(train))
+                microbunch_ids.append(mib_array[i, index:train_end])
+                indices.append(slice(index, train_end))
+                index = train_end + 1
+            macrobunch_indices.append(indices)
+
+        # Create a series with the macrobunches as index and
+        # microbunches as values
+        macrobunches = (
+            Series(
+                (microbunch_ids[i] for i in range(len(macrobunch_index))),
+                name="pulseId",
+                index=macrobunch_index,
+            )
+            - self._config.ubid_offset
+        )
+
+        # Explode dataframe to get all microbunch values per macrobunch,
+        # remove NaN values and convert to type int
+        microbunches = macrobunches.explode().dropna().astype(int)
+
+        # Create temporary index values
+        index_temp = MultiIndex.from_arrays(
+            (microbunches.index, microbunches.values),
+            names=["trainId", "pulseId"],
+        )
+
+        # Calculate the electron counts per pulseId; unique() preserves the order of appearance
+        electron_counts = index_temp.value_counts()[index_temp.unique()].values
+
+        # Series object for indexing with electrons
+        electrons = (
+            Series(
+                [np.arange(electron_counts[i]) for i in range(electron_counts.size)],
+            )
+            .explode()
+            .astype(int)
+        )
+
+        # Create a pandas MultiIndex using the exploded datasets
+        index = MultiIndex.from_arrays(
+            (microbunches.index, microbunches.values, electrons),
+            names=self.multi_index,
+        )
+        return macrobunch_indices, index
+
+    @property
+    def df_electron(self) -> DataFrame:
+        """
+        Returns a pandas DataFrame for 'per_electron' data.
+
+        Returns:
+            DataFrame: The pandas DataFrame for the 'per_electron' data.
+        """
+
+        # Get the channels for the dataframe
+        channels = get_channels(self._config.channels, "per_electron")
+        if not channels:
+            return DataFrame()
+
+        series = []
+        for channel in channels:
+            array_indices, index = self.pulse_index()
+            _, np_array = self.get_dataset_array(channel)
+            if array_indices is None or len(array_indices) != np_array.shape[0]:
+                raise RuntimeError(
+                    "macrobunch_indices not set correctly, internal inconsistency detected.",
+                )
+            train_data = []
+            for i, _ in enumerate(array_indices):
+                for indices in array_indices[i]:
+                    train_data.append(np_array[i, indices])
+
+            drop_vals = np.arange(-self._config.ubid_offset, 0)
+            series.append(
+                Series((train for train in train_data), name=channel)
+                .explode()
+                .dropna()
+                .to_frame()
+                .set_index(index)
+                .drop(
+                    index=drop_vals,
+                    level=1,
+                    errors="ignore",
+                ),
+            )
+        print(series)
+
+        return concat(series, axis=1)
diff --git a/sed/loader/sxp/loader.py b/sed/loader/sxp/loader.py
index ce89e546..1be552b1 100644
--- a/sed/loader/sxp/loader.py
+++ b/sed/loader/sxp/loader.py
@@ -1,37 +1,27 @@
-# pylint: disable=duplicate-code
 """
 This module implements the SXP data loader.
 This loader currently supports the SXP momentum microscope instrument.
The raw hdf5 data is combined and saved into buffer files and loaded as a dask dataframe. -The dataframe is a amalgamation of all h5 files for a combination of runs, where the NaNs are -automatically forward filled across different files. +The dataframe is an amalgamation of all h5 files for a combination of runs, where the NaNs are +automatically forward-filled across different files. This can then be saved as a parquet for out-of-sed processing and reread back to access other -sed funtionality. -Most of the structure is identical to the FLASH loader. +sed functionality. """ +from __future__ import annotations + import time -from functools import reduce from pathlib import Path -from typing import List +from typing import cast from typing import Sequence -from typing import Tuple -from typing import Union import dask.dataframe as dd -import h5py import numpy as np -import pyarrow.parquet as pq -from joblib import delayed -from joblib import Parallel from natsort import natsorted -from pandas import DataFrame -from pandas import MultiIndex -from pandas import Series -from sed.core import dfops from sed.loader.base.loader import BaseLoader -from sed.loader.utils import parse_h5_keys -from sed.loader.utils import split_dld_time_from_sector_id +from sed.loader.fel import BufferHandler +from sed.loader.fel.config_model import LoaderConfig +from sed.loader.sxp.dataframe import SXPDataFrameCreator class SXPLoader(BaseLoader): @@ -46,72 +36,24 @@ class SXPLoader(BaseLoader): supported_file_types = ["h5"] def __init__(self, config: dict) -> None: - super().__init__(config=config) - self.multi_index = ["trainId", "pulseId", "electronId"] - self.index_per_electron: MultiIndex = None - self.index_per_pulse: MultiIndex = None - self.failed_files_error: List[str] = [] - self.array_indices: List[List[slice]] = None - - def initialize_paths(self) -> Tuple[List[Path], Path]: """ - Initializes the paths based on the configuration. + Initializes the FlashLoader. - Returns: - Tuple[List[Path], Path]: A tuple containing a list of raw data directories - paths and the parquet data directory path. - - Raises: - ValueError: If required values are missing from the configuration. - FileNotFoundError: If the raw data directories are not found. + Args: + config (dict): The configuration dictionary or model. 
""" - # Parses to locate the raw beamtime directory from config file - if ( - "paths" in self._config["core"] - and self._config["core"]["paths"].get("data_raw_dir", "") - and self._config["core"]["paths"].get("data_parquet_dir", "") - ): - data_raw_dir = [ - Path(self._config["core"]["paths"].get("data_raw_dir", "")), - ] - data_parquet_dir = Path( - self._config["core"]["paths"].get("data_parquet_dir", ""), - ) - - else: - try: - beamtime_id = self._config["core"]["beamtime_id"] - year = self._config["core"]["year"] - except KeyError as exc: - raise ValueError( - "The beamtime_id and year are required.", - ) from exc - - beamtime_dir = Path( - self._config["dataframe"]["beamtime_dir"][self._config["core"]["beamline"]], - ) - beamtime_dir = beamtime_dir.joinpath(f"{year}/{beamtime_id}/") - - if not beamtime_dir.joinpath("raw").is_dir(): - raise FileNotFoundError("Raw data directory not found.") - - data_raw_dir = [beamtime_dir.joinpath("raw")] - - parquet_path = "processed/parquet" - data_parquet_dir = beamtime_dir.joinpath(parquet_path) - - data_parquet_dir.mkdir(parents=True, exist_ok=True) - - return data_raw_dir, data_parquet_dir + super().__init__(config=config) + self.config = LoaderConfig(**self._config) def get_files_from_run_id( self, run_id: str, - folders: Union[str, Sequence[str]] = None, + folders: str | Sequence[str] = None, extension: str = "h5", **kwds, - ) -> List[str]: - """Returns a list of filenames for a given run located in the specified directory + ) -> list[str]: + """ + Returns a list of filenames for a given run located in the specified directory for the specified data acquisition (daq). Args: @@ -119,8 +61,6 @@ def get_files_from_run_id( folders (Union[str, Sequence[str]], optional): The directory(ies) where the raw data is located. Defaults to config["core"]["base_folder"]. extension (str, optional): The file extension. Defaults to "h5". - kwds: Keyword arguments: - - daq (str): The data acquisition identifier. Returns: List[str]: A list of path strings representing the collected file names. @@ -129,25 +69,22 @@ def get_files_from_run_id( FileNotFoundError: If no files are found for the given run in the directory. """ # Define the stream name prefixes based on the data acquisition identifier - stream_name_prefixes = self._config["dataframe"]["stream_name_prefixes"] - stream_name_postfixes = self._config["dataframe"].get("stream_name_postfixes", {}) + stream_name_prefix = self.config.dataframe.stream_name_prefix + stream_name_postfix = self.config.dataframe.stream_name_postfix if isinstance(run_id, (int, np.integer)): run_id = str(run_id).zfill(4) if folders is None: - folders = self._config["core"]["base_folder"] + folders = self.config.core.base_folder if isinstance(folders, str): folders = [folders] - daq = kwds.pop("daq", self._config.get("dataframe", {}).get("daq")) - - stream_name_postfix = stream_name_postfixes.get(daq, "") # Generate the file patterns to search for in the directory - file_pattern = f"**/{stream_name_prefixes[daq]}{run_id}{stream_name_postfix}*." + extension + file_pattern = f"**/{stream_name_prefix}{run_id}{stream_name_postfix}*." 
+ extension - files: List[Path] = [] + files: list[Path] = [] # Use pathlib to search for matching files in each directory for folder in folders: files.extend( @@ -166,721 +103,6 @@ def get_files_from_run_id( # Return the list of found files return [str(file.resolve()) for file in files] - @property - def available_channels(self) -> List: - """Returns the channel names that are available for use, - excluding pulseId, defined by the json file""" - available_channels = list(self._config["dataframe"]["channels"].keys()) - available_channels.remove("pulseId") - available_channels.remove("trainId") - return available_channels - - def get_channels(self, formats: Union[str, List[str]] = "", index: bool = False) -> List[str]: - """ - Returns a list of channels associated with the specified format(s). - - Args: - formats (Union[str, List[str]]): The desired format(s) - ('per_pulse', 'per_electron', 'per_train', 'all'). - index (bool): If True, includes channels from the multi_index. - - Returns: - List[str]: A list of channels with the specified format(s). - """ - # If 'formats' is a single string, convert it to a list for uniform processing. - if isinstance(formats, str): - formats = [formats] - - # If 'formats' is a string "all", gather all possible formats. - if formats == ["all"]: - channels = self.get_channels(["per_pulse", "per_train", "per_electron"], index) - return channels - - channels = [] - for format_ in formats: - # Gather channels based on the specified format(s). - channels.extend( - key - for key in self.available_channels - if self._config["dataframe"]["channels"][key]["format"] == format_ - and key != "dldAux" - ) - # Include 'dldAuxChannels' if the format is 'per_pulse'. - if format_ == "per_pulse" and "dldAux" in self._config["dataframe"]["channels"]: - channels.extend( - self._config["dataframe"]["channels"]["dldAux"]["dldAuxChannels"].keys(), - ) - - # Include channels from multi_index if 'index' is True. - if index: - channels.extend(self.multi_index) - - return channels - - def reset_multi_index(self) -> None: - """Resets the index per pulse and electron""" - self.index_per_electron = None - self.index_per_pulse = None - self.array_indices = None - - def create_multi_index_per_electron(self, h5_file: h5py.File) -> None: - """ - Creates an index per electron using pulseId for usage with the electron - resolved pandas DataFrame. - - Args: - h5_file (h5py.File): The HDF5 file object. - - Notes: - - This method relies on the 'pulseId' channel to determine - the macrobunch IDs. - - It creates a MultiIndex with trainId, pulseId, and electronId - as the index levels. 
- """ - - # relative macrobunch IDs obtained from the trainId channel - train_id, mab_array = self.create_numpy_array_per_channel( - h5_file, - "trainId", - ) - # Internal microbunch IDs obtained from the pulseId channel - train_id, mib_array = self.create_numpy_array_per_channel( - h5_file, - "pulseId", - ) - - # Chopping data into trains - macrobunch_index = [] - microbunch_ids = [] - macrobunch_indices = [] - for i in train_id.index: - # removing broken trailing hit copies - num_trains = self._config["dataframe"].get("num_trains", 0) - if num_trains: - try: - num_valid_hits = np.where(np.diff(mib_array[i].astype(np.int32)) < 0)[0][ - num_trains - 1 - ] - mab_array[i, num_valid_hits:] = 0 - mib_array[i, num_valid_hits:] = 0 - except IndexError: - pass - train_ends = np.where(np.diff(mib_array[i].astype(np.int32)) < -1)[0] - indices = [] - index = 0 - for train, train_end in enumerate(train_ends): - macrobunch_index.append(train_id[i] + np.uint(train)) - microbunch_ids.append(mib_array[i, index:train_end]) - indices.append(slice(index, train_end)) - index = train_end + 1 - macrobunch_indices.append(indices) - self.array_indices = macrobunch_indices - # Create a series with the macrobunches as index and - # microbunches as values - macrobunches = ( - Series( - (microbunch_ids[i] for i in range(len(macrobunch_index))), - name="pulseId", - index=macrobunch_index, - ) - - self._config["dataframe"]["ubid_offset"] - ) - - # Explode dataframe to get all microbunch vales per macrobunch, - # remove NaN values and convert to type int - microbunches = macrobunches.explode().dropna().astype(int) - - # Create temporary index values - index_temp = MultiIndex.from_arrays( - (microbunches.index, microbunches.values), - names=["trainId", "pulseId"], - ) - - # Calculate the electron counts per pulseId unique preserves the order of appearance - electron_counts = index_temp.value_counts()[index_temp.unique()].values - - # Series object for indexing with electrons - electrons = ( - Series( - [np.arange(electron_counts[i]) for i in range(electron_counts.size)], - ) - .explode() - .astype(int) - ) - - # Create a pandas MultiIndex using the exploded datasets - self.index_per_electron = MultiIndex.from_arrays( - (microbunches.index, microbunches.values, electrons), - names=self.multi_index, - ) - - def create_multi_index_per_pulse( - self, - train_id: Series, - np_array: np.ndarray, - ) -> None: - """ - Creates an index per pulse using a pulse resolved channel's macrobunch ID, for usage with - the pulse resolved pandas DataFrame. - - Args: - train_id (Series): The train ID Series. - np_array (np.ndarray): The numpy array containing the pulse resolved data. - - Notes: - - This method creates a MultiIndex with trainId and pulseId as the index levels. - """ - - # Create a pandas MultiIndex, useful for comparing electron and - # pulse resolved dataframes - self.index_per_pulse = MultiIndex.from_product( - (train_id, np.arange(0, np_array.shape[1])), - names=["trainId", "pulseId"], - ) - - def create_numpy_array_per_channel( - self, - h5_file: h5py.File, - channel: str, - ) -> Tuple[Series, np.ndarray]: - """ - Returns a numpy array for a given channel name for a given file. - - Args: - h5_file (h5py.File): The h5py file object. - channel (str): The name of the channel. - - Returns: - Tuple[Series, np.ndarray]: A tuple containing the train ID Series and the numpy array - for the channel's data. 
- - """ - # Get the data from the necessary h5 file and channel - dataset = h5_file[self._config["dataframe"]["channels"][channel]["dataset_key"]] - index = h5_file[self._config["dataframe"]["channels"][channel]["index_key"]] - - channel_dict = self._config["dataframe"]["channels"][channel] # channel parameters - - train_id = Series(index, name="trainId") # macrobunch - - # unpacks the data into np.ndarray - np_array = dataset[()] - if len(np_array.shape) == 2 and self._config["dataframe"]["channels"][channel].get( - "max_hits", - 0, - ): - np_array = np_array[:, : self._config["dataframe"]["channels"][channel]["max_hits"]] - - # Use predefined axis and slice from the json file - # to choose correct dimension for necessary channel - if "slice" in channel_dict: - np_array = np.take( - np_array, - channel_dict["slice"], - axis=1, - ) - - if "scale" in channel_dict: - np_array = np_array / float(channel_dict["scale"]) - - return train_id, np_array - - def create_dataframe_per_electron( - self, - np_array: np.ndarray, - channel: str, - ) -> DataFrame: - """ - Returns a pandas DataFrame for a given channel name of type [per electron]. - - Args: - np_array (np.ndarray): The numpy array containing the channel data. - channel (str): The name of the channel. - - Returns: - DataFrame: The pandas DataFrame for the channel's data. - - Notes: - The microbunch resolved data is exploded and converted to a DataFrame. The MultiIndex - is set, and the NaN values are dropped, alongside the pulseId = 0 (meaningless). - - """ - if self.array_indices is None or len(self.array_indices) != np_array.shape[0]: - raise RuntimeError( - "macrobunch_indices not set correctly, internal inconstency detected.", - ) - train_data = [] - for i, _ in enumerate(self.array_indices): - for indices in self.array_indices[i]: - train_data.append(np_array[i, indices]) - return ( - Series((train for train in train_data), name=channel) - .explode() - .dropna() - .to_frame() - .set_index(self.index_per_electron) - .drop( - index=np.arange(-self._config["dataframe"]["ubid_offset"], 0), - level=1, - errors="ignore", - ) - ) - - def create_dataframe_per_pulse( - self, - np_array: np.ndarray, - train_id: Series, - channel: str, - channel_dict: dict, - ) -> DataFrame: - """ - Returns a pandas DataFrame for a given channel name of type [per pulse]. - - Args: - np_array (np.ndarray): The numpy array containing the channel data. - train_id (Series): The train ID Series. - channel (str): The name of the channel. - channel_dict (dict): The dictionary containing channel parameters. - - Returns: - DataFrame: The pandas DataFrame for the channel's data. - - Notes: - - For auxillary channels, the macrobunch resolved data is repeated 499 times to be - compared to electron resolved data for each auxillary channel. The data is then - converted to a multicolumn DataFrame. - - For all other pulse resolved channels, the macrobunch resolved data is exploded - to a DataFrame and the MultiIndex is set. 
- - """ - - # Special case for auxillary channels - if channel == "dldAux": - # Checks the channel dictionary for correct slices and creates a multicolumn DataFrame - data_frames = ( - Series( - (np_array[i, value] for i in train_id.index), - name=key, - index=train_id, - ).to_frame() - for key, value in channel_dict["dldAuxChannels"].items() - ) - - # Multiindex set and combined dataframe returned - data = reduce(DataFrame.combine_first, data_frames) - - # For all other pulse resolved channels - else: - # Macrobunch resolved data is exploded to a DataFrame and the MultiIndex is set - - # Creates the index_per_pulse for the given channel - self.create_multi_index_per_pulse(train_id, np_array) - data = ( - Series((np_array[i] for i in train_id.index), name=channel) - .explode() - .to_frame() - .set_index(self.index_per_pulse) - ) - - return data - - def create_dataframe_per_train( - self, - np_array: np.ndarray, - train_id: Series, - channel: str, - ) -> DataFrame: - """ - Returns a pandas DataFrame for a given channel name of type [per train]. - - Args: - np_array (np.ndarray): The numpy array containing the channel data. - train_id (Series): The train ID Series. - channel (str): The name of the channel. - - Returns: - DataFrame: The pandas DataFrame for the channel's data. - """ - return ( - Series((np_array[i] for i in train_id.index), name=channel) - .to_frame() - .set_index(train_id) - ) - - def create_dataframe_per_channel( - self, - h5_file: h5py.File, - channel: str, - ) -> Union[Series, DataFrame]: - """ - Returns a pandas DataFrame for a given channel name from a given file. - - This method takes an h5py.File object `h5_file` and a channel name `channel`, and returns - a pandas DataFrame containing the data for that channel from the file. The format of the - DataFrame depends on the channel's format specified in the configuration. - - Args: - h5_file (h5py.File): The h5py.File object representing the HDF5 file. - channel (str): The name of the channel. - - Returns: - Union[Series, DataFrame]: A pandas Series or DataFrame representing the channel's data. - - Raises: - ValueError: If the channel has an undefined format. 
- - """ - [train_id, np_array] = self.create_numpy_array_per_channel( - h5_file, - channel, - ) # numpy Array created - channel_dict = self._config["dataframe"]["channels"][channel] # channel parameters - - # If np_array is size zero, fill with NaNs - if np_array.size == 0: - # Fill the np_array with NaN values of the same shape as train_id - np_array = np.full_like(train_id, np.nan, dtype=np.double) - # Create a Series using np_array, with train_id as the index - data = Series( - (np_array[i] for i in train_id.index), - name=channel, - index=train_id, - ) - - # Electron resolved data is treated here - if channel_dict["format"] == "per_electron": - # If index_per_electron is None, create it for the given file - if self.index_per_electron is None: - self.create_multi_index_per_electron(h5_file) - - # Create a DataFrame for electron-resolved data - data = self.create_dataframe_per_electron( - np_array, - channel, - ) - - # Pulse resolved data is treated here - elif channel_dict["format"] == "per_pulse": - # Create a DataFrame for pulse-resolved data - data = self.create_dataframe_per_pulse( - np_array, - train_id, - channel, - channel_dict, - ) - - # Train resolved data is treated here - elif channel_dict["format"] == "per_train": - # Create a DataFrame for train-resolved data - data = self.create_dataframe_per_train(np_array, train_id, channel) - - else: - raise ValueError( - channel - + "has an undefined format. Available formats are \ - per_pulse, per_electron and per_train", - ) - - return data - - def concatenate_channels( - self, - h5_file: h5py.File, - ) -> DataFrame: - """ - Concatenates the channels from the provided h5py.File into a pandas DataFrame. - - This method takes an h5py.File object `h5_file` and concatenates the channels present in - the file into a single pandas DataFrame. The concatenation is performed based on the - available channels specified in the configuration. - - Args: - h5_file (h5py.File): The h5py.File object representing the HDF5 file. - - Returns: - DataFrame: A concatenated pandas DataFrame containing the channels. - - Raises: - ValueError: If the group_name for any channel does not exist in the file. - - """ - all_keys = parse_h5_keys(h5_file) # Parses all channels present - - # Check for if the provided dataset_keys and index_keys actually exists in the file - for channel in self._config["dataframe"]["channels"]: - dataset_key = self._config["dataframe"]["channels"][channel]["dataset_key"] - if dataset_key not in all_keys: - raise ValueError( - f"The dataset_key for channel {channel} does not exist.", - ) - index_key = self._config["dataframe"]["channels"][channel]["index_key"] - if index_key not in all_keys: - raise ValueError( - f"The index_key for channel {channel} does not exist.", - ) - - # Create a generator expression to generate data frames for each channel - data_frames = ( - self.create_dataframe_per_channel(h5_file, each) for each in self.available_channels - ) - - # Use the reduce function to join the data frames into a single DataFrame - return reduce( - lambda left, right: left.join(right, how="outer"), - data_frames, - ) - - def create_dataframe_per_file( - self, - file_path: Path, - ) -> DataFrame: - """ - Create pandas DataFrames for the given file. - - This method loads an HDF5 file specified by `file_path` and constructs a pandas DataFrame - from the datasets within the file. The order of datasets in the DataFrames is the opposite - of the order specified by channel names. - - Args: - file_path (Path): Path to the input HDF5 file. 
- - Returns: - DataFrame: pandas DataFrame - - """ - # Loads h5 file and creates a dataframe - with h5py.File(file_path, "r") as h5_file: - self.reset_multi_index() # Reset MultiIndexes for next file - df = self.concatenate_channels(h5_file) - df = df.dropna(subset=self._config["dataframe"].get("tof_column", "dldTimeSteps")) - # correct the 3 bit shift which encodes the detector ID in the 8s time - if self._config["dataframe"].get("split_sector_id_from_dld_time", False): - df = split_dld_time_from_sector_id(df, config=self._config) - return df - - def create_buffer_file(self, h5_path: Path, parquet_path: Path) -> Union[bool, Exception]: - """ - Converts an HDF5 file to Parquet format to create a buffer file. - - This method uses `create_dataframe_per_file` method to create dataframes from individual - files within an HDF5 file. The resulting dataframe is then saved to a Parquet file. - - Args: - h5_path (Path): Path to the input HDF5 file. - parquet_path (Path): Path to the output Parquet file. - - Raises: - ValueError: If an error occurs during the conversion process. - - """ - try: - ( - self.create_dataframe_per_file(h5_path) - .reset_index(level=self.multi_index) - .to_parquet(parquet_path, index=False) - ) - except Exception as exc: # pylint: disable=broad-except - self.failed_files_error.append(f"{parquet_path}: {type(exc)} {exc}") - return exc - return None - - def buffer_file_handler( - self, - data_parquet_dir: Path, - detector: str, - force_recreate: bool, - ) -> Tuple[List[Path], List, List]: - """ - Handles the conversion of buffer files (h5 to parquet) and returns the filenames. - - Args: - data_parquet_dir (Path): Directory where the parquet files will be stored. - detector (str): Detector name. - force_recreate (bool): Forces recreation of buffer files - - Returns: - Tuple[List[Path], List, List]: Three lists, one for - parquet file paths, one for metadata and one for schema. - - Raises: - FileNotFoundError: If the conversion fails for any files or no data is available. - """ - - # Create the directory for buffer parquet files - buffer_file_dir = data_parquet_dir.joinpath("buffer") - buffer_file_dir.mkdir(parents=True, exist_ok=True) - - # Create two separate lists for h5 and parquet file paths - h5_filenames = [Path(file) for file in self.files] - parquet_filenames = [ - buffer_file_dir.joinpath(Path(file).stem + detector) for file in self.files - ] - existing_parquet_filenames = [file for file in parquet_filenames if file.exists()] - - # Raise a value error if no data is available after the conversion - if len(h5_filenames) == 0: - raise ValueError("No data available. 
Probably failed reading all h5 files") - - if not force_recreate: - # Check if the available channels match the schema of the existing parquet files - parquet_schemas = [pq.read_schema(file) for file in existing_parquet_filenames] - config_schema = set(self.get_channels(formats="all", index=True)) - if self._config["dataframe"].get("split_sector_id_from_dld_time", False): - config_schema.add(self._config["dataframe"].get("sector_id_column", False)) - - for i, schema in enumerate(parquet_schemas): - schema_set = set(schema.names) - if schema_set != config_schema: - missing_in_parquet = config_schema - schema_set - missing_in_config = schema_set - config_schema - - missing_in_parquet_str = ( - f"Missing in parquet: {missing_in_parquet}" if missing_in_parquet else "" - ) - missing_in_config_str = ( - f"Missing in config: {missing_in_config}" if missing_in_config else "" - ) - - raise ValueError( - "The available channels do not match the schema of file", - f"{existing_parquet_filenames[i]}", - f"{missing_in_parquet_str}", - f"{missing_in_config_str}", - "Please check the configuration file or set force_recreate to True.", - ) - - # Choose files to read - files_to_read = [ - (h5_path, parquet_path) - for h5_path, parquet_path in zip(h5_filenames, parquet_filenames) - if force_recreate or not parquet_path.exists() - ] - - print(f"Reading files: {len(files_to_read)} new files of {len(h5_filenames)} total.") - - # Initialize the indices for create_buffer_file conversion - self.reset_multi_index() - - # Convert the remaining h5 files to parquet in parallel if there are any - if len(files_to_read) > 0: - error = Parallel(n_jobs=len(files_to_read), verbose=10)( - delayed(self.create_buffer_file)(h5_path, parquet_path) - for h5_path, parquet_path in files_to_read - ) - if any(error): - raise RuntimeError(f"Conversion failed for some files. {error}") - # for h5_path, parquet_path in files_to_read: - # self.create_buffer_file(h5_path, parquet_path) - - # Raise an error if the conversion failed for any files - # TODO: merge this and the previous error trackings - if self.failed_files_error: - raise FileNotFoundError( - "Conversion failed for the following files:\n" + "\n".join(self.failed_files_error), - ) - - print("All files converted successfully!") - - # read all parquet metadata and schema - metadata = [pq.read_metadata(file) for file in parquet_filenames] - schema = [pq.read_schema(file) for file in parquet_filenames] - - return parquet_filenames, metadata, schema - - def parquet_handler( - self, - data_parquet_dir: Path, - detector: str = "", - parquet_path: Path = None, - converted: bool = False, - load_parquet: bool = False, - save_parquet: bool = False, - force_recreate: bool = False, - ) -> Tuple[dd.DataFrame, dd.DataFrame]: - """ - Handles loading and saving of parquet files based on the provided parameters. - - Args: - data_parquet_dir (Path): Directory where the parquet files are located. - detector (str, optional): Adds a identifier for parquets to distinguish multidetector - systems. - parquet_path (str, optional): Path to the combined parquet file. - converted (bool, optional): True if data is augmented by adding additional columns - externally and saved into converted folder. - load_parquet (bool, optional): Loads the entire parquet into the dd dataframe. - save_parquet (bool, optional): Saves the entire dataframe into a parquet. - force_recreate (bool, optional): Forces recreation of buffer file. 
- Returns: - tuple: A tuple containing two dataframes: - - dataframe_electron: Dataframe containing the loaded/augmented electron data. - - dataframe_pulse: Dataframe containing the loaded/augmented timed data. - - Raises: - FileNotFoundError: If the requested parquet file is not found. - - """ - - # Construct the parquet path if not provided - if parquet_path is None: - parquet_name = "_".join(str(run) for run in self.runs) - parquet_dir = data_parquet_dir.joinpath("converted") if converted else data_parquet_dir - - parquet_path = parquet_dir.joinpath( - "run_" + parquet_name + detector, - ).with_suffix(".parquet") - - # Check if load_parquet is flagged and then load the file if it exists - if load_parquet: - try: - dataframe = dd.read_parquet(parquet_path) - except Exception as exc: - raise FileNotFoundError( - "The final parquet for this run(s) does not exist yet. " - "If it is in another location, please provide the path as parquet_path.", - ) from exc - - else: - # Obtain the parquet filenames, metadata and schema from the method - # which handles buffer file creation/reading - filenames, metadata, _ = self.buffer_file_handler( - data_parquet_dir, - detector, - force_recreate, - ) - - # Read all parquet files into one dataframe using dask - dataframe = dd.read_parquet(filenames, calculate_divisions=True) - - # Channels to fill NaN values - channels: List[str] = self.get_channels(["per_pulse", "per_train"]) - - overlap = min(file.num_rows for file in metadata) - - print("Filling nan values...") - dataframe = dfops.forward_fill_lazy( - df=dataframe, - columns=channels, - before=overlap, - iterations=self._config["dataframe"].get("forward_fill_iterations", 2), - ) - # Remove the NaNs from per_electron channels - dataframe_electron = dataframe.dropna( - subset=self.get_channels(["per_electron"]), - ) - dataframe_pulse = dataframe[ - self.multi_index + self.get_channels(["per_pulse", "per_train"]) - ] - dataframe_pulse = dataframe_pulse[ - (dataframe_pulse["electronId"] == 0) | (np.isnan(dataframe_pulse["electronId"])) - ] - - # Save the dataframe as parquet if requested - if save_parquet: - dataframe_electron.compute().reset_index(drop=True).to_parquet(parquet_path) - print("Combined parquet file saved.") - - return dataframe_electron, dataframe_pulse - def gather_metadata(self, metadata: dict = None) -> dict: """Dummy function returning empty metadata dictionary for now. @@ -893,7 +115,6 @@ def gather_metadata(self, metadata: dict = None) -> dict: """ if metadata is None: metadata = {} - return metadata def get_count_rate( @@ -908,14 +129,16 @@ def get_elapsed_time(self, fids=None, **kwds): def read_dataframe( self, - files: Union[str, Sequence[str]] = None, - folders: Union[str, Sequence[str]] = None, - runs: Union[str, Sequence[str]] = None, + files: str | Sequence[str] = None, + folders: str | Sequence[str] = None, + runs: str | Sequence[str] = None, ftype: str = "h5", metadata: dict = None, collect_metadata: bool = False, + force_recreate: bool = False, + debug: bool = False, **kwds, - ) -> Tuple[dd.DataFrame, dd.DataFrame, dict]: + ) -> tuple[dd.DataFrame, dd.DataFrame, dict]: """ Read express data from the DAQ, generating a parquet in between. @@ -925,14 +148,15 @@ def read_dataframe( Path has priority such that if it's specified, the specified files will be ignored. Defaults to None. runs (Union[str, Sequence[str]], optional): Run identifier(s). Corresponding files will - be located in the location provided by ``folders``. 
Takes precendence over + be located in the location provided by ``folders``. Takes precedence over ``files`` and ``folders``. Defaults to None. ftype (str, optional): The file extension type. Defaults to "h5". metadata (dict, optional): Additional metadata. Defaults to None. collect_metadata (bool, optional): Whether to collect metadata. Defaults to False. Returns: - Tuple[dd.DataFrame, dict]: A tuple containing the concatenated DataFrame and metadata. + Tuple[dd.DataFrame, dd.DataFrame, dict]: A tuple containing the concatenated DataFrame + and metadata. Raises: ValueError: If neither 'runs' nor 'files'/'data_raw_dir' is provided. @@ -940,7 +164,9 @@ def read_dataframe( """ t0 = time.time() - data_raw_dir, data_parquet_dir = self.initialize_paths() + paths = self.config.core.paths + data_raw_dir = paths.data_raw_dir + data_parquet_dir = paths.data_parquet_dir # Prepare a list of names for the runs to read and parquets to write if runs is not None: @@ -950,9 +176,8 @@ def read_dataframe( for run in runs: run_files = self.get_files_from_run_id( run_id=run, - folders=[str(folder.resolve()) for folder in data_raw_dir], + folders=[str(folder.resolve()) for folder in [data_raw_dir]], extension=ftype, - daq=self._config["dataframe"]["daq"], ) files.extend(run_files) self.runs = list(runs) @@ -968,7 +193,19 @@ def read_dataframe( metadata=metadata, ) - df, df_timed = self.parquet_handler(data_parquet_dir, **kwds) + # Obtain the parquet filenames, metadata, and schema from the method + # which handles buffer file creation/reading + h5_paths = [Path(file) for file in self.files] + buffer = BufferHandler( + SXPDataFrameCreator, + self.config.dataframe, + h5_paths, + data_parquet_dir, + force_recreate, + debug=debug, + ) + df = buffer.dataframe_electron + df_timed = buffer.dataframe_pulse if collect_metadata: metadata = self.gather_metadata( diff --git a/sed/loader/utils.py b/sed/loader/utils.py index ab3fde3a..0d26671c 100644 --- a/sed/loader/utils.py +++ b/sed/loader/utils.py @@ -162,7 +162,7 @@ def split_dld_time_from_sector_id( sector_id_column (str, optional): Name of the column containing the sectorID. Defaults to config["dataframe"]["sector_id_column"]. sector_id_reserved_bits (int, optional): Number of bits reserved for the - config (dict, optional): Configuration dictionary. Defaults to None. + config (dict, optional): Dataframe configuration dictionary. Defaults to None. Returns: Union[pd.DataFrame, dask.dataframe.DataFrame]: Dataframe with the new columns. 
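The hunk that follows changes the calling convention of split_dld_time_from_sector_id: callers now pass the dataframe section of the configuration instead of the full config dictionary. A minimal sketch under that assumption (column names follow the test configuration in this change; the sample dataframe values are made up):

# Sketch of the new calling convention; the sample dataframe values are illustrative only.
import pandas as pd

from sed.loader.utils import split_dld_time_from_sector_id

# Previously the full config was passed and the function looked up config["dataframe"][...];
# now the dataframe section is handed over directly.
dataframe_config = {
    "tof_column": "dldTimeSteps",
    "sector_id_column": "dldSectorID",
    "sector_id_reserved_bits": 3,  # the 3 bits encoding the detector ID
}

df = pd.DataFrame({"dldTimeSteps": [8, 9, 18, 27]})  # made-up raw time steps
df = split_dld_time_from_sector_id(df, config=dataframe_config)
print(df[["dldSectorID", "dldTimeSteps"]])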
@@ -170,15 +170,15 @@ def split_dld_time_from_sector_id( if tof_column is None: if config is None: raise ValueError("Either tof_column or config must be given.") - tof_column = config["dataframe"]["tof_column"] + tof_column = config["tof_column"] if sector_id_column is None: if config is None: raise ValueError("Either sector_id_column or config must be given.") - sector_id_column = config["dataframe"]["sector_id_column"] + sector_id_column = config["sector_id_column"] if sector_id_reserved_bits is None: if config is None: raise ValueError("Either sector_id_reserved_bits or config must be given.") - sector_id_reserved_bits = config["dataframe"].get("sector_id_reserved_bits", None) + sector_id_reserved_bits = config.get("sector_id_reserved_bits", None) if sector_id_reserved_bits is None: raise ValueError('No value for "sector_id_reserved_bits" found in config.') diff --git a/tests/calibrator/test_energy.py b/tests/calibrator/test_energy.py index ed5a40bf..d3e35b65 100644 --- a/tests/calibrator/test_energy.py +++ b/tests/calibrator/test_energy.py @@ -24,6 +24,8 @@ folder = package_dir + "/../tests/data/calibrator/" files = glob.glob(df_folder + "*.h5") +config_path_flash = os.path.join(package_dir, "../tests/data/loader/flash/config.yaml") + traces_list = [] with open(folder + "traces.csv", newline="", encoding="utf-8") as csvfile: reader = csv.reader(csvfile, quoting=csv.QUOTE_NONNUMERIC) @@ -584,7 +586,7 @@ def test_add_offsets_functionality() -> None: }, }, }, - folder_config={}, + folder_config=config_path_flash, user_config={}, system_config={}, ) @@ -661,7 +663,12 @@ def test_add_offset_raises() -> None: with pytest.raises(KeyError): cfg = deepcopy(cfg_dict) cfg["energy"]["offsets"]["off1"].pop("weight") - config = parse_config(config=cfg, folder_config={}, user_config={}, system_config={}) + config = parse_config( + config=cfg, + folder_config=config_path_flash, + user_config={}, + system_config={}, + ) ec = EnergyCalibrator(config=cfg, loader=get_loader("flash", config=config)) _ = ec.add_offsets(t_df) @@ -669,7 +676,12 @@ def test_add_offset_raises() -> None: with pytest.raises(ValueError): cfg = deepcopy(cfg_dict) cfg["energy"]["calibration"].pop("energy_scale") - config = parse_config(config=cfg, folder_config={}, user_config={}, system_config={}) + config = parse_config( + config=cfg, + folder_config=config_path_flash, + user_config={}, + system_config={}, + ) ec = EnergyCalibrator(config=cfg, loader=get_loader("flash", config=config)) _ = ec.add_offsets(t_df) @@ -677,7 +689,12 @@ def test_add_offset_raises() -> None: with pytest.raises(ValueError): cfg = deepcopy(cfg_dict) cfg["energy"]["calibration"]["energy_scale"] = "wrong_value" - config = parse_config(config=cfg, folder_config={}, user_config={}, system_config={}) + config = parse_config( + config=cfg, + folder_config=config_path_flash, + user_config={}, + system_config={}, + ) ec = EnergyCalibrator(config=cfg, loader=get_loader("flash", config=config)) _ = ec.add_offsets(t_df) @@ -685,7 +702,12 @@ def test_add_offset_raises() -> None: with pytest.raises(TypeError): cfg = deepcopy(cfg_dict) cfg["energy"]["offsets"]["off1"]["weight"] = "wrong_type" - config = parse_config(config=cfg, folder_config={}, user_config={}, system_config={}) + config = parse_config( + config=cfg, + folder_config=config_path_flash, + user_config={}, + system_config={}, + ) ec = EnergyCalibrator(config=cfg, loader=get_loader("flash", config=config)) _ = ec.add_offsets(t_df) @@ -693,7 +715,12 @@ def test_add_offset_raises() -> None: with 
pytest.raises(TypeError): cfg = deepcopy(cfg_dict) cfg["energy"]["offsets"]["constant"] = "wrong_type" - config = parse_config(config=cfg, folder_config={}, user_config={}, system_config={}) + config = parse_config( + config=cfg, + folder_config=config_path_flash, + user_config={}, + system_config={}, + ) ec = EnergyCalibrator(config=cfg, loader=get_loader("flash", config=config)) _ = ec.add_offsets(t_df) @@ -714,7 +741,12 @@ def test_align_dld_sectors() -> None: }, ) # from config - config = parse_config(config=cfg_dict, folder_config={}, user_config={}, system_config={}) + config = parse_config( + config=cfg_dict, + folder_config=config_path_flash, + user_config={}, + system_config={}, + ) ec = EnergyCalibrator(config=config, loader=get_loader("flash", config=config)) t_df = dask.dataframe.from_pandas(df.copy(), npartitions=2) res, meta = ec.align_dld_sectors(t_df) @@ -726,7 +758,12 @@ def test_align_dld_sectors() -> None: ) # from kwds - config = parse_config(config={}, folder_config={}, user_config={}, system_config={}) + config = parse_config( + config=config_path_flash, + folder_config={}, + user_config={}, + system_config={}, + ) ec = EnergyCalibrator(config=cfg_dict, loader=get_loader("flash", config=config)) t_df = dask.dataframe.from_pandas(df.copy(), npartitions=2) res, meta = ec.align_dld_sectors( diff --git a/tests/data/loader/flash/FLASH1_USER3_stream_2_run43878_file1_20230130T153807.1.h5 b/tests/data/loader/flash/FLASH1_USER3_stream_2_run43878_file1_20230130T153807.1.h5 index 02a04d9f..1eafbf33 100644 Binary files a/tests/data/loader/flash/FLASH1_USER3_stream_2_run43878_file1_20230130T153807.1.h5 and b/tests/data/loader/flash/FLASH1_USER3_stream_2_run43878_file1_20230130T153807.1.h5 differ diff --git a/tests/data/loader/flash/FLASH1_USER3_stream_2_run43879_file1_20230130T153807.1.h5 b/tests/data/loader/flash/FLASH1_USER3_stream_2_run43879_file1_20230130T153807.1.h5 new file mode 100644 index 00000000..1524e5ad Binary files /dev/null and b/tests/data/loader/flash/FLASH1_USER3_stream_2_run43879_file1_20230130T153807.1.h5 differ diff --git a/tests/data/loader/flash/config.yaml b/tests/data/loader/flash/config.yaml index b7049dfb..2b1757ca 100644 --- a/tests/data/loader/flash/config.yaml +++ b/tests/data/loader/flash/config.yaml @@ -81,23 +81,28 @@ dataframe: # pulse ID is a necessary channel for using the loader. 
pulseId: format: per_electron - group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" + # group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" + index_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/index" + dataset_key: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/value" slice: 2 dldPosX: format: per_electron group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" slice: 1 + dtype: uint16 dldPosY: format: per_electron group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" slice: 0 + dtype: uint16 dldTimeSteps: format: per_electron group_name: "/uncategorised/FLASH.EXP/HEXTOF.DAQ/DLD1/" slice: 3 + dtype: uint32 # The auxillary channel has a special structure where the group further contains # a multidim structure so further aliases are defined below @@ -107,7 +112,8 @@ dataframe: slice: 4 dldAuxChannels: sampleBias: 0 - tofVoltage: 1 + tofVoltage: + slice: 1 extractorVoltage: 2 extractorCurrent: 3 cryoTemperature: 4 @@ -122,9 +128,14 @@ dataframe: format: per_train group_name: "/zraw/FLASH.SYNC/LASER.LOCK.EXP/F1.PG.OSC/FMC0.MD22.1.ENCODER_POSITION.RD/dGroup/" + pulserSignAdc: + format: per_pulse + group_name: "/FL1/Experiment/PG/SIS8300 100MHz ADC/CH6/TD/" + gmdTunnel: format: per_pulse group_name: "/FL1/Photon Diagnostic/GMD/Pulse resolved energy/energy tunnel/" + slice: 0 # The prefixes of the stream names for different DAQ systems for parsing filenames # (Not to be changed by user) diff --git a/tests/loader/conftest.py b/tests/loader/conftest.py new file mode 100644 index 00000000..3dfc345b --- /dev/null +++ b/tests/loader/conftest.py @@ -0,0 +1,126 @@ +""" This module contains fixtures for the FEL module tests. +""" +import os +import shutil +from importlib.util import find_spec + +import h5py +import pytest + +from sed.core.config import parse_config +from sed.loader.fel.config_model import DataFrameConfig +from sed.loader.fel.config_model import LoaderConfig + +package_dir = os.path.dirname(find_spec("sed").origin) +config_path = os.path.join(package_dir, "../tests/data/loader/flash/config.yaml") +H5_PATH = "FLASH1_USER3_stream_2_run43878_file1_20230130T153807.1.h5" +H5_PATHS = [H5_PATH, "FLASH1_USER3_stream_2_run43879_file1_20230130T153807.1.h5"] + + +@pytest.fixture(name="config_raw") +def fixture_config_raw_file() -> dict: + """Fixture providing a configuration file for FlashLoader tests. + + Returns: + dict: The parsed configuration file. + """ + return parse_config(config_path) + + +@pytest.fixture(name="config") +def fixture_config_file(config_raw) -> LoaderConfig: + """Fixture providing a configuration file for FlashLoader tests. + + Returns: + dict: The parsed configuration file. + """ + return LoaderConfig(**config_raw) + + +@pytest.fixture(name="config_dataframe") +def fixture_config_file_dataframe(config) -> DataFrameConfig: + """Fixture providing a configuration file for FlashLoader tests. + + Returns: + dict: The parsed configuration file. + """ + return config.dataframe + + +@pytest.fixture(name="h5_file") +def fixture_h5_file(): + """Fixture providing an open h5 file. + + Returns: + h5py.File: The open h5 file. + """ + return h5py.File(os.path.join(package_dir, f"../tests/data/loader/flash/{H5_PATH}"), "r") + + +@pytest.fixture(name="h5_file_copy") +def fixture_h5_file_copy(tmp_path): + """Fixture providing a copy of an open h5 file. + + Returns: + h5py.File: The open h5 file copy. 
+ """ + # Create a copy of the h5 file in a temporary directory + original_file_path = os.path.join(package_dir, f"../tests/data/loader/flash/{H5_PATH}") + copy_file_path = tmp_path / "copy.h5" + shutil.copyfile(original_file_path, copy_file_path) + + # Open the copy in 'read-write' mode and return it + return h5py.File(copy_file_path, "r+") + + +@pytest.fixture(name="h5_paths") +def fixture_h5_paths(): + """Fixture providing a list of h5 file paths. + + Returns: + list: A list of h5 file paths. + """ + return [os.path.join(package_dir, f"../tests/data/loader/flash/{path}") for path in H5_PATHS] + + +# @pytest.fixture(name="pulserSignAdc_channel_array") +# def get_pulse_channel_from_h5(config_dataframe, h5_file): +# df = DataFrameCreator(config_dataframe) +# df.h5_file = h5_file +# train_id, pulse_id = df.get_dataset_array("pulserSignAdc") +# return train_id, pulse_id + + +# @pytest.fixture(name="multiindex_electron") +# def fixture_multi_index_electron(config_dataframe, h5_file): +# """Fixture providing multi index for electron resolved data""" +# df = DataFrameCreator(config_dataframe) +# df.h5_file = h5_file +# pulse_index, indexer = df.pulse_index(config_dataframe["ubid_offset"]) + +# return pulse_index, indexer + + +# @pytest.fixture(name="fake_data") +# def fake_data_electron(): +# # Creating manageable fake data, but not used currently +# num_trains = 5 +# max_pulse_id = 100 +# nan_threshold = 50 +# ubid_offset = 5 +# seed = 42 +# np.random.seed(seed) +# train_ids = np.arange(1600000000, 1600000000 + num_trains) +# fake_data = [] + +# for _ in train_ids: +# pulse_ids = [] +# while len(pulse_ids) < nan_threshold: +# random_pulse_ids = np.random.choice( +# np.arange(ubid_offset, nan_threshold), size=np.random.randint(0, 10)) +# pulse_ids = np.concatenate([pulse_ids, random_pulse_ids]) + +# pulse_ids = np.concatenate([pulse_ids, np.full(max_pulse_id-len(pulse_ids), np.nan)]) + +# fake_data.append(np.sort(pulse_ids)) +# return Series(train_ids, name="trainId"), np.array(fake_data), ubid_offset diff --git a/tests/loader/fel/__init__.py b/tests/loader/fel/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/loader/fel/test_buffer_handler.py b/tests/loader/fel/test_buffer_handler.py new file mode 100644 index 00000000..86fbeb52 --- /dev/null +++ b/tests/loader/fel/test_buffer_handler.py @@ -0,0 +1,178 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from sed.loader.fel import BufferHandler +from sed.loader.fel.utils import get_channels +from sed.loader.flash.dataframe import FlashDataFrameCreator + + +def create_parquet_dir(config, folder): + parquet_path = Path(config.core.paths.data_parquet_dir) + parquet_path = parquet_path.joinpath(folder) + parquet_path.mkdir(parents=True, exist_ok=True) + return parquet_path + + +def test_get_files_to_read(config, h5_paths): + folder = create_parquet_dir(config, "get_files_to_read") + subfolder = folder.joinpath("buffer") + # set to false to avoid creating buffer files unnecessarily + bh = BufferHandler(FlashDataFrameCreator, config.dataframe, h5_paths, folder, auto=False) + bh.get_files_to_read(h5_paths, folder, "", "", False) + + assert bh.num_files == len(h5_paths) + assert len(bh.buffer_to_create) == len(h5_paths) + + assert np.all(bh.h5_to_create == h5_paths) + + # create expected paths + expected_buffer_paths = [Path(subfolder, f"{Path(path).stem}") for path in h5_paths] + + assert np.all(bh.buffer_to_create == expected_buffer_paths) + + # create only one buffer file + 
bh._create_buffer_file(h5_paths[0], expected_buffer_paths[0]) + # check again for files to read + bh.get_files_to_read(h5_paths, folder, "", "", False) + # check that only one file is to be read + assert bh.num_files == len(h5_paths) - 1 + Path(expected_buffer_paths[0]).unlink() # remove buffer file + + # add prefix and suffix + bh.get_files_to_read(h5_paths, folder, "prefix_", "_suffix", False) + + # expected buffer paths with prefix and suffix + expected_buffer_paths = [ + Path(subfolder, f"prefix_{Path(path).stem}_suffix") for path in h5_paths + ] + assert np.all(bh.buffer_to_create == expected_buffer_paths) + + +def test_buffer_schema_mismatch(config, h5_paths): + """ + Test function to verify schema mismatch handling in the FlashLoader's 'read_dataframe' method. + + The test validates the error handling mechanism when the available channels do not match the + schema of the existing parquet files. + + Test Steps: + - Attempt to read a dataframe after adding a new channel 'gmdTunnel2' to the configuration. + - Check for an expected error related to the mismatch between available channels and schema. + - Force recreation of dataframe with the added channel, ensuring successful dataframe + creation. + - Simulate a missing channel scenario by removing 'gmdTunnel2' from the configuration. + - Check for an error indicating a missing channel in the configuration. + - Clean up created buffer files after the test. + """ + folder = create_parquet_dir(config, "schema_mismatch") + bh = BufferHandler( + FlashDataFrameCreator, + config.dataframe, + h5_paths, + folder, + auto=True, + debug=True, + ) + + # Manipulate the configuration to introduce a new channel 'gmdTunnel2' + config_alt = config + gmdTunnel2 = config_alt.dataframe.channels["gmdTunnel"] + gmdTunnel2.group_name = "/FL1/Photon Diagnostic/GMD/Pulse resolved energy/energy tunnel/" + gmdTunnel2.format = "per_pulse" + gmdTunnel2.slice = 0 + config_alt.dataframe.channels["gmdTunnel2"] = gmdTunnel2 + + # Reread the dataframe with the modified configuration, expecting a schema mismatch error + with pytest.raises(ValueError) as e: + bh = BufferHandler( + FlashDataFrameCreator, + config.dataframe, + h5_paths, + folder, + auto=True, + debug=True, + ) + expected_error = e.value.args + + # Validate the specific error messages for schema mismatch + assert "The available channels do not match the schema of file" in expected_error[0] + assert expected_error[2] == "Missing in parquet: {'gmdTunnel2'}" + assert expected_error[4] == "Please check the configuration file or set force_recreate to True." 
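+    # Illustrative only (pyarrow is not imported in this test): the stored schema can be
+    # inspected with pyarrow, e.g.
+    #   set(pq.read_schema(bh.buffer_paths[0]).names) should equal
+    #   set(get_channels(config.dataframe.channels, "all", index=True, extend_aux=True))
+    # once the buffer files are recreated with the new channel.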
+ + # Force recreation of the dataframe, including the added channel 'gmdTunnel2' + bh = BufferHandler( + FlashDataFrameCreator, + config.dataframe, + h5_paths, + folder, + auto=True, + force_recreate=True, + debug=True, + ) + + # Remove 'gmdTunnel2' from the configuration to simulate a missing channel scenario + del config.dataframe.channels["gmdTunnel2"] + # also results in error but different from before + with pytest.raises(ValueError) as e: + # Attempt to read the dataframe again to check for the missing channel error + bh = BufferHandler( + FlashDataFrameCreator, + config.dataframe, + h5_paths, + folder, + auto=True, + debug=True, + ) + + expected_error = e.value.args + # Check for the specific error message indicating a missing channel in the configuration + assert expected_error[3] == "Missing in config: {'gmdTunnel2'}" + + # Clean up created buffer files after the test + [path.unlink() for path in bh.buffer_paths] + + +def test_create_buffer_files(config, h5_paths): + folder_serial = create_parquet_dir(config, "create_buffer_files_serial") + bh_serial = BufferHandler( + FlashDataFrameCreator, + config.dataframe, + h5_paths, + folder_serial, + debug=True, + ) + + folder_parallel = create_parquet_dir(config, "create_buffer_files_parallel") + bh_parallel = BufferHandler(FlashDataFrameCreator, config.dataframe, h5_paths, folder_parallel) + + df_serial = pd.read_parquet(folder_serial) + df_parallel = pd.read_parquet(folder_parallel) + + pd.testing.assert_frame_equal(df_serial, df_parallel) + + # remove buffer files + [path.unlink() for path in bh_serial.buffer_paths] + [path.unlink() for path in bh_parallel.buffer_paths] + + +def test_get_filled_dataframe(config, h5_paths): + """Test function to verify the creation of a filled dataframe from the buffer files.""" + folder = create_parquet_dir(config, "get_filled_dataframe") + bh = BufferHandler(FlashDataFrameCreator, config.dataframe, h5_paths, folder) + + df = pd.read_parquet(folder) + + assert np.all(list(bh.dataframe_electron.columns) == list(df.columns) + ["dldSectorID"]) + + channel_pulse = get_channels( + config.dataframe.channels, + formats=["per_pulse", "per_train"], + index=True, + extend_aux=True, + ) + assert np.all(list(bh.dataframe_pulse.columns) == channel_pulse) + # remove buffer files + [path.unlink() for path in bh.buffer_paths] diff --git a/tests/loader/fel/test_parquet_handler.py b/tests/loader/fel/test_parquet_handler.py new file mode 100644 index 00000000..0645f97e --- /dev/null +++ b/tests/loader/fel/test_parquet_handler.py @@ -0,0 +1,70 @@ +from pathlib import Path + +import dask.dataframe as ddf +import pandas as pd +import pytest + +from sed.loader.fel import BufferHandler +from sed.loader.fel import ParquetHandler +from sed.loader.flash.dataframe import FlashDataFrameCreator + + +def create_parquet_dir(config, folder): + parquet_path = Path(config.core.paths.data_parquet_dir) + parquet_path = parquet_path.joinpath(folder) + parquet_path.mkdir(parents=True, exist_ok=True) + return parquet_path + + +def test_parquet_init_error(): + """Test ParquetHandler initialization error""" + with pytest.raises(ValueError) as e: + ParquetHandler(parquet_names="test") + + assert "Please provide folder or parquet_paths." in str(e.value) + + with pytest.raises(ValueError) as e: + ParquetHandler(folder="test") + + assert "With folder, please provide parquet_names." 
in str(e.value) + + +def test_initialize_paths(config): + """Test ParquetHandler initialization""" + folder = create_parquet_dir(config, "parquet_init") + + ph = ParquetHandler("test", folder, extension="xyz") + assert ph.parquet_paths[0].suffix == ".xyz" + assert ph.parquet_paths[0].name == "test.xyz" + + # test prefix and suffix + ph = ParquetHandler("test", folder, prefix="prefix_", suffix="_suffix") + assert ph.parquet_paths[0].name == "prefix_test_suffix.parquet" + + # test with list of parquet_names and subfolder + ph = ParquetHandler(["test1", "test2"], folder, subfolder="subfolder") + assert ph.parquet_paths[0].parent.name == "subfolder" + assert ph.parquet_paths[0].name == "test1.parquet" + assert ph.parquet_paths[1].name == "test2.parquet" + + +def test_save_read_parquet(config, h5_paths): + """Test ParquetHandler save and read parquet""" + # provide instead parquet_paths + folder = create_parquet_dir(config, "parquet_save") + parquet_path = folder.joinpath("test.parquet") + + ph = ParquetHandler(parquet_paths=parquet_path) + print(ph.parquet_paths) + bh = BufferHandler(FlashDataFrameCreator, config.dataframe, h5_paths, folder) + ph.save_parquet(bh.dataframe_electron, drop_index=True) + parquet_path.unlink() + ph.save_parquet(bh.dataframe_electron, drop_index=False) + + df = ph.read_parquet() + + [path.unlink() for path in bh.buffer_paths] + parquet_path.unlink() + # Test file not found + with pytest.raises(FileNotFoundError) as e: + ph.read_parquet() diff --git a/tests/loader/fel/test_utils.py b/tests/loader/fel/test_utils.py new file mode 100644 index 00000000..b482f9c4 --- /dev/null +++ b/tests/loader/fel/test_utils.py @@ -0,0 +1,74 @@ +"""Tests for utils functionality""" +from sed.loader.fel.utils import get_channels + +# Define expected channels for each format. +ELECTRON_CHANNELS = ["dldPosX", "dldPosY", "dldTimeSteps"] +PULSE_CHANNELS = ["pulserSignAdc", "gmdTunnel"] +TRAIN_CHANNELS = ["timeStamp", "delayStage", "dldAux"] +TRAIN_CHANNELS_EXTENDED = [ + "sampleBias", + "tofVoltage", + "extractorVoltage", + "extractorCurrent", + "cryoTemperature", + "sampleTemperature", + "dldTimeBinSize", + "timeStamp", + "delayStage", +] +INDEX_CHANNELS = ["trainId", "pulseId", "electronId"] + + +def test_get_channels_by_format(config_dataframe): + """ + Test function to verify the 'get_channels' method in FlashLoader class for + retrieving channels based on formats and index inclusion. + """ + # Initialize the FlashLoader instance with the given config_file. + ch_dict = config_dataframe.channels + + # Call get_channels method with different format options. + + # Request channels for 'per_electron' format using a list. + format_electron = get_channels(ch_dict, ["per_electron"]) + + # Request channels for 'per_pulse' format using a string. + format_pulse = get_channels(ch_dict, "per_pulse") + + # Request channels for 'per_train' format without expanding the dldAuxChannels. + format_train = get_channels(ch_dict, "per_train", extend_aux=False) + + # Request channels for 'per_train' format using a list, and expand the dldAuxChannels. + format_train_extended = get_channels(ch_dict, ["per_train"], extend_aux=True) + + # Request channels for 'all' formats using a list. + format_all = get_channels(ch_dict, ["all"]) + + # Request index channels only. No need for channel_dict. + format_index = get_channels(index=True) + + # Request 'per_electron' format and include index channels. 
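+    # (index channels are trainId, pulseId and electronId, cf. INDEX_CHANNELS above)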
+ format_index_electron = get_channels(ch_dict, ["per_electron"], index=True) + + # Request 'all' formats and include index channels. + format_all_index = get_channels(ch_dict, ["all"], index=True) + + # Request 'all' formats and include index channels and extend aux channels + format_all_index_extend_aux = get_channels(ch_dict, "all", index=True, extend_aux=True) + + # Assert that the obtained channels match the expected channels. + assert set(ELECTRON_CHANNELS) == set(format_electron) + assert set(TRAIN_CHANNELS_EXTENDED) == set(format_train_extended) + assert set(TRAIN_CHANNELS) == set(format_train) + assert set(PULSE_CHANNELS) == set(format_pulse) + assert set(ELECTRON_CHANNELS + TRAIN_CHANNELS + PULSE_CHANNELS) == set(format_all) + assert set(INDEX_CHANNELS) == set(format_index) + assert set(INDEX_CHANNELS + ELECTRON_CHANNELS) == set(format_index_electron) + assert set(INDEX_CHANNELS + ELECTRON_CHANNELS + TRAIN_CHANNELS + PULSE_CHANNELS) == set( + format_all_index, + ) + assert set( + INDEX_CHANNELS + ELECTRON_CHANNELS + PULSE_CHANNELS + TRAIN_CHANNELS_EXTENDED, + ) == set( + format_all_index_extend_aux, + ) diff --git a/tests/loader/flash/test_dataframe_creator.py b/tests/loader/flash/test_dataframe_creator.py new file mode 100644 index 00000000..664c443e --- /dev/null +++ b/tests/loader/flash/test_dataframe_creator.py @@ -0,0 +1,239 @@ +"""Tests for FlashDataFrameCreator functionality""" +import h5py +import numpy as np +import pytest +from pandas import DataFrame +from pandas import Index +from pandas import MultiIndex + +from sed.loader.fel.utils import get_channels +from sed.loader.flash.dataframe import FlashDataFrameCreator + + +def test_get_dataset_array(config_dataframe, h5_file): + """Test the creation of a h5py dataset for a given channel.""" + + df = FlashDataFrameCreator(config_dataframe, h5_file) + channel = "dldPosX" + + train_id, dset = df.get_dataset_array(channel) + # Check that the train_id and np_array have the correct shapes and types + assert isinstance(train_id, Index) + assert isinstance(dset, h5py.Dataset) + assert train_id.name == "trainId" + assert train_id.shape[0] == dset.shape[0] + assert dset.shape[1] == 5 + assert dset.shape[2] == 321 + + train_id, dset = df.get_dataset_array(channel, slice_=True) + assert train_id.shape[0] == dset.shape[0] + assert dset.shape[1] == 321 + + channel = "gmdTunnel" + train_id, dset = df.get_dataset_array(channel, True) + assert train_id.shape[0] == dset.shape[0] + assert dset.shape[1] == 500 + + +def test_empty_get_dataset_array(config_dataframe, h5_file, h5_file_copy): + """Test the method when given an empty dataset.""" + + channel = "gmdTunnel" + df = FlashDataFrameCreator(config_dataframe, h5_file) + train_id, dset = df.get_dataset_array(channel) + + channel_index_key = config_dataframe.channels.get(channel).group_name + "index" + empty_dataset_key = config_dataframe.channels.get(channel).group_name + "empty" + config_dataframe.channels.get(channel).index_key = channel_index_key + config_dataframe.channels.get(channel).dataset_key = empty_dataset_key + # Remove the 'group_name' key + del config_dataframe.channels.get(channel).group_name + + # create an empty dataset + h5_file_copy.create_dataset( + name=empty_dataset_key, + shape=(train_id.shape[0], 0), + ) + + df = FlashDataFrameCreator(config_dataframe, h5_file) + df.h5_file = h5_file_copy + train_id, dset_empty = df.get_dataset_array(channel) + + assert dset_empty.shape[0] == train_id.shape[0] + assert dset.shape[1] == 8 + assert dset_empty.shape[1] == 0 + + +def 
test_pulse_index(config_dataframe, h5_file): + """Test the creation of the pulse index for electron resolved data""" + + df = FlashDataFrameCreator(config_dataframe, h5_file) + pulse_index, pulse_array = df.get_dataset_array("pulseId", slice_=True) + index, indexer = df.pulse_index(config_dataframe.ubid_offset) + # Check if the index_per_electron is a MultiIndex and has the correct levels + assert isinstance(index, MultiIndex) + assert set(index.names) == {"trainId", "pulseId", "electronId"} + + # Check if the pulse_index has the correct number of elements + # This should be the pulses without nan values + pulse_rav = pulse_array.ravel() + pulse_no_nan = pulse_rav[~np.isnan(pulse_rav)] + assert len(index) == len(pulse_no_nan) + + # Check if all pulseIds are correctly mapped to the index + assert np.all( + index.get_level_values("pulseId").values + == (pulse_no_nan - config_dataframe.ubid_offset)[indexer], + ) + + assert np.all( + index.get_level_values("electronId").values[:5] == [0, 1, 0, 1, 0], + ) + + assert np.all( + index.get_level_values("electronId").values[-5:] == [1, 0, 1, 0, 1], + ) + + # check if all indexes are unique and monotonic increasing + assert index.is_unique + assert index.is_monotonic_increasing + + +def test_df_electron(config_dataframe, h5_file): + """Test the creation of a pandas DataFrame for a channel of type [per electron].""" + df = FlashDataFrameCreator(config_dataframe, h5_file) + + result_df = df.df_electron + + # check index levels + assert set(result_df.index.names) == {"trainId", "pulseId", "electronId"} + + # check that there are no nan values in the dataframe + assert ~result_df.isnull().values.any() + + # Check that the values are dropped for pulseId index below 0 (ubid_offset) + print( + np.all( + result_df.values[:5] + != np.array( + [ + [556.0, 731.0, 42888.0], + [549.0, 737.0, 42881.0], + [671.0, 577.0, 39181.0], + [671.0, 579.0, 39196.0], + [714.0, 859.0, 37530.0], + ], + ), + ), + ) + assert np.all(result_df.index.get_level_values("pulseId") >= 0) + assert isinstance(result_df, DataFrame) + + assert result_df.index.is_unique + + # check that dataframe contains all subchannels + assert np.all( + set(result_df.columns) == set(get_channels(config_dataframe.channels, ["per_electron"])), + ) + + +def test_create_dataframe_per_pulse(config_dataframe, h5_file): + """Test the creation of a pandas DataFrame for a channel of type [per pulse].""" + df = FlashDataFrameCreator(config_dataframe, h5_file) + result_df = df.df_pulse + # Check that the result_df is a DataFrame and has the correct shape + assert isinstance(result_df, DataFrame) + + _, data = df.get_dataset_array("gmdTunnel", slice_=True) + assert result_df.shape[0] == data.shape[0] * data.shape[1] + + # check index levels + assert set(result_df.index.names) == {"trainId", "pulseId", "electronId"} + + # all electronIds should be 0 + assert np.all(result_df.index.get_level_values("electronId") == 0) + + # pulse ids should span 0-499 on each train + assert np.all( + result_df.loc[1648851402].index.get_level_values("pulseId").values == np.arange(500), + ) + + # assert index uniqueness + assert result_df.index.is_unique + + # assert that dataframe contains all channels + assert np.all( + set(result_df.columns) == set(get_channels(config_dataframe.channels, ["per_pulse"])), + ) + + +def test_create_dataframe_per_train(config_dataframe, h5_file): + """Test the creation of a pandas DataFrame for a channel of type [per train].""" + df = FlashDataFrameCreator(config_dataframe, h5_file) + result_df = 
df.df_train + + channel = "delayStage" + key, data = df.get_dataset_array(channel, slice_=True) + + # Check that the result_df is a DataFrame and has the correct shape + assert isinstance(result_df, DataFrame) + + # check that all values are in the df for delayStage + np.all(result_df[channel].dropna() == data[()]) + + # check that dataframe contains all channels + assert np.all( + set(result_df.columns) + == set(get_channels(config_dataframe.channels, ["per_train"], extend_aux=True)), + ) + + # find unique index values among all per_train channels + channels = get_channels(config_dataframe.channels, ["per_train"]) + all_keys = Index([]) + for channel in channels: + all_keys = all_keys.append(df.get_dataset_array(channel, slice_=True)[0]) + assert result_df.shape[0] == len(all_keys.unique()) + + # check index levels + assert set(result_df.index.names) == {"trainId", "pulseId", "electronId"} + + # all pulseIds and electronIds should be 0 + assert np.all(result_df.index.get_level_values("pulseId") == 0) + assert np.all(result_df.index.get_level_values("electronId") == 0) + + channel = "dldAux" + key, data = df.get_dataset_array(channel, slice_=True) + + # Check if the subchannels are correctly sliced into the dataframe + # The values are stored in DLD which is a 2D array + # The subchannels are stored in the second dimension + # Only index amount of values are stored in the first dimension, the rest are NaNs + # hence the slicing + aux_channels = config_dataframe.channels["dldAux"].dldAuxChannels + for aux_ch_name in aux_channels: + aux_ch = aux_channels[aux_ch_name] + np.all(df.df_train[aux_ch_name].dropna().values == data[: key.size, aux_ch.slice]) + + assert result_df.index.is_unique + + +def test_group_name_not_in_h5(config_dataframe, h5_file): + """Test ValueError when the group_name for a channel does not exist in the H5 file.""" + channel = "dldPosX" + config = config_dataframe + config.channels.get(channel).index_key = "foo" + df = FlashDataFrameCreator(config, h5_file) + with pytest.raises(KeyError): + df.df_electron + + +# def test_create_dataframe_per_file(config_dataframe, h5_file): +# """Test the creation of pandas DataFrames for a given file.""" +# df = FlashDataFrameCreator(config_dataframe, h5_file) +# result_df = df.df + +# # Check that the result_df is a DataFrame and has the correct shape +# assert isinstance(result_df, DataFrame) +# all_keys = df.df_train.index.append(df.df_electron.index).append(df.df_pulse.index) +# all_keys = all_keys.unique() +# assert result_df.shape[0] == len(all_keys.unique()) diff --git a/tests/loader/flash/test_flash_loader.py b/tests/loader/flash/test_flash_loader.py index edff997e..dc3a8008 100644 --- a/tests/loader/flash/test_flash_loader.py +++ b/tests/loader/flash/test_flash_loader.py @@ -4,9 +4,10 @@ from pathlib import Path from typing import Literal +import pandas as pd import pytest +from pydantic_core import ValidationError -from sed.core.config import parse_config from sed.loader.flash.loader import FlashLoader package_dir = os.path.dirname(find_spec("sed").origin) @@ -14,80 +15,12 @@ H5_PATH = "FLASH1_USER3_stream_2_run43878_file1_20230130T153807.1.h5" -@pytest.fixture(name="config_file") -def fixture_config_file() -> dict: - """Fixture providing a configuration file for FlashLoader tests. - - Returns: - dict: The parsed configuration file. 
- """ - return parse_config(config_path) - - -def test_get_channels_by_format(config_file: dict) -> None: - """ - Test function to verify the 'get_channels' method in FlashLoader class for - retrieving channels based on formats and index inclusion. - """ - # Initialize the FlashLoader instance with the given config_file. - fl = FlashLoader(config_file) - - # Define expected channels for each format. - electron_channels = ["dldPosX", "dldPosY", "dldTimeSteps"] - pulse_channels = [ - "sampleBias", - "tofVoltage", - "extractorVoltage", - "extractorCurrent", - "cryoTemperature", - "sampleTemperature", - "dldTimeBinSize", - "gmdTunnel", - ] - train_channels = ["timeStamp", "delayStage"] - index_channels = ["trainId", "pulseId", "electronId"] - - # Call get_channels method with different format options. - - # Request channels for 'per_electron' format using a list. - format_electron = fl.get_channels(["per_electron"]) - - # Request channels for 'per_pulse' format using a string. - format_pulse = fl.get_channels("per_pulse") - - # Request channels for 'per_train' format using a list. - format_train = fl.get_channels(["per_train"]) - - # Request channels for 'all' formats using a list. - format_all = fl.get_channels(["all"]) - - # Request index channels only. - format_index = fl.get_channels(index=True) - - # Request 'per_electron' format and include index channels. - format_index_electron = fl.get_channels(["per_electron"], index=True) - - # Request 'all' formats and include index channels. - format_all_index = fl.get_channels(["all"], index=True) - - # Assert that the obtained channels match the expected channels. - assert set(electron_channels) == set(format_electron) - assert set(pulse_channels) == set(format_pulse) - assert set(train_channels) == set(format_train) - assert set(electron_channels + pulse_channels + train_channels) == set(format_all) - assert set(index_channels) == set(format_index) - assert set(index_channels + electron_channels) == set(format_index_electron) - assert set(index_channels + electron_channels + pulse_channels + train_channels) == set( - format_all_index, - ) - - @pytest.mark.parametrize( "sub_dir", ["online-0/fl1user3/", "express-0/fl1user3/", "FL1USER3/"], ) def test_initialize_paths( - config_file: dict, + config_raw, fs, sub_dir: Literal["online-0/fl1user3/", "express-0/fl1user3/", "FL1USER3/"], ) -> None: @@ -98,15 +31,15 @@ def test_initialize_paths( fs: A fixture for a fake file system. sub_dir (Literal["online-0/fl1user3/", "express-0/fl1user3/", "FL1USER3/"]): Sub-directory. """ - config = config_file - del config["core"]["paths"] - config["core"]["beamtime_id"] = "12345678" - config["core"]["year"] = "2000" + config_dict = config_raw + del config_dict["core"]["paths"] + config_dict["core"]["beamtime_id"] = "12345678" + config_dict["core"]["year"] = "2000" # Find base path of beamline from config. 
Here, we use pg2 - base_path = config["dataframe"]["beamtime_dir"]["pg2"] + base_path = config_dict["dataframe"]["beamtime_dir"]["pg2"] expected_path = ( - Path(base_path) / config["core"]["year"] / "data" / config["core"]["beamtime_id"] + Path(base_path) / config_dict["core"]["year"] / "data" / config_dict["core"]["beamtime_id"] ) # Create expected paths expected_raw_path = expected_path / "raw" / "hdf" / sub_dir @@ -117,109 +50,36 @@ def test_initialize_paths( fs.create_dir(expected_processed_path) # Instance of class with correct config and call initialize_paths - fl = FlashLoader(config=config) - data_raw_dir, data_parquet_dir = fl.initialize_paths() + # config_alt = FlashLoaderConfig(**config_dict) + fl = FlashLoader(config=config_dict) + data_raw_dir = fl.config.core.paths.data_raw_dir + data_parquet_dir = fl.config.core.paths.data_parquet_dir - assert expected_raw_path == data_raw_dir[0] + assert expected_raw_path == data_raw_dir assert expected_processed_path == data_parquet_dir + # remove breamtimeid, year and daq from config to raise error + del config_dict["core"]["beamtime_id"] + with pytest.raises(ValidationError) as e: + fl = FlashLoader(config=config_dict) -def test_initialize_paths_filenotfound(config_file: dict) -> None: - """ - Test FileNotFoundError during the initialization of paths. - """ - # Test the FileNotFoundError - config = config_file - del config["core"]["paths"] - config["core"]["beamtime_id"] = "11111111" - config["core"]["year"] = "2000" - - # Instance of class with correct config and call initialize_paths - fl = FlashLoader(config=config) - with pytest.raises(FileNotFoundError): - _, _ = fl.initialize_paths() - - -def test_invalid_channel_format(config_file: dict) -> None: - """ - Test ValueError for an invalid channel format. - """ - config = config_file - config["dataframe"]["channels"]["dldPosX"]["format"] = "foo" - - fl = FlashLoader(config=config) - - with pytest.raises(ValueError): - fl.read_dataframe() - - -def test_group_name_not_in_h5(config_file: dict) -> None: - """ - Test ValueError when the group_name for a channel does not exist in the H5 file. - """ - config = config_file - config["dataframe"]["channels"]["dldPosX"]["group_name"] = "foo" - fl = FlashLoader(config=config) + error_messages = [error["msg"] for error in e.value.errors()] + assert ( + "Value error, Either 'paths' or 'beamtime_id' and 'year' must be provided." + in error_messages + ) - with pytest.raises(ValueError) as e: - fl.create_dataframe_per_file(Path(config["core"]["paths"]["data_raw_dir"] + H5_PATH)) - assert str(e.value.args[0]) == "The group_name for channel dldPosX does not exist." +def test_save_read_parquet_flash(config_raw): + """Test ParquetHandler save and read parquet""" + config_alt = config_raw + fl = FlashLoader(config=config_alt) + fl.config.core.paths.data_parquet_dir = fl.config.core.paths.data_parquet_dir.joinpath( + "_flash_save_read/", + ) + df1, _, _ = fl.read_dataframe(runs=[43878, 43879], save_parquet=True) + df2, _, _ = fl.read_dataframe(runs=[43878, 43879], load_parquet=True) -def test_buffer_schema_mismatch(config_file: dict) -> None: - """ - Test function to verify schema mismatch handling in the FlashLoader's 'read_dataframe' method. - - The test validates the error handling mechanism when the available channels do not match the - schema of the existing parquet files. - - Test Steps: - - Attempt to read a dataframe after adding a new channel 'gmdTunnel2' to the configuration. 
- - Check for an expected error related to the mismatch between available channels and schema. - - Force recreation of dataframe with the added channel, ensuring successful dataframe creation. - - Simulate a missing channel scenario by removing 'gmdTunnel2' from the configuration. - - Check for an error indicating a missing channel in the configuration. - - Clean up created buffer files after the test. - """ - fl = FlashLoader(config=config_file) - - # Read a dataframe for a specific run - fl.read_dataframe(runs=["43878"]) - - # Manipulate the configuration to introduce a new channel 'gmdTunnel2' - config = config_file - config["dataframe"]["channels"]["gmdTunnel2"] = { - "group_name": "/FL1/Photon Diagnostic/GMD/Pulse resolved energy/energy tunnel/", - "format": "per_pulse", - } - - # Reread the dataframe with the modified configuration, expecting a schema mismatch error - fl = FlashLoader(config=config) - with pytest.raises(ValueError) as e: - fl.read_dataframe(runs=["43878"]) - expected_error = e.value.args - - # Validate the specific error messages for schema mismatch - assert "The available channels do not match the schema of file" in expected_error[0] - assert expected_error[2] == "Missing in parquet: {'gmdTunnel2'}" - assert expected_error[4] == "Please check the configuration file or set force_recreate to True." - - # Force recreation of the dataframe, including the added channel 'gmdTunnel2' - fl.read_dataframe(runs=["43878"], force_recreate=True) - - # Remove 'gmdTunnel2' from the configuration to simulate a missing channel scenario - del config["dataframe"]["channels"]["gmdTunnel2"] - fl = FlashLoader(config=config) - with pytest.raises(ValueError) as e: - # Attempt to read the dataframe again to check for the missing channel error - fl.read_dataframe(runs=["43878"]) - - expected_error = e.value.args - # Check for the specific error message indicating a missing channel in the configuration - assert expected_error[3] == "Missing in config: {'gmdTunnel2'}" - - # Clean up created buffer files after the test - _, parquet_data_dir = fl.initialize_paths() - for file in os.listdir(Path(parquet_data_dir, "buffer")): - os.remove(Path(parquet_data_dir, "buffer", file)) + # check if parquet read is same as parquet saved read correctly + pd.testing.assert_frame_equal(df1.compute().reset_index(drop=True), df2.compute()) diff --git a/tests/loader/sxp/conftest.py b/tests/loader/sxp/conftest.py new file mode 100644 index 00000000..a3976922 --- /dev/null +++ b/tests/loader/sxp/conftest.py @@ -0,0 +1,115 @@ +""" This module contains fixtures for the FEL module tests. +""" +import os +import shutil +from importlib.util import find_spec + +import h5py +import pytest + +from sed.core.config import parse_config +from sed.loader.fel.config_model import DataFrameConfig +from sed.loader.fel.config_model import LoaderConfig + +package_dir = os.path.dirname(find_spec("sed").origin) +config_path = os.path.join(package_dir, "../tests/data/loader/sxp/config.yaml") +H5_PATH = "RAW-R0016-DA03-S00000.h5" + + +@pytest.fixture(name="config_raw") +def fixture_config_raw_file() -> dict: + """Fixture providing a configuration file for FlashLoader tests. + + Returns: + dict: The parsed configuration file. + """ + return parse_config(config_path) + + +@pytest.fixture(name="config") +def fixture_config_file(config_raw) -> LoaderConfig: + """Fixture providing a configuration file for FlashLoader tests. + + Returns: + dict: The parsed configuration file. 
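+    Note: the raw dictionary is validated into a LoaderConfig model.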
+ """ + return LoaderConfig(**config_raw) + + +@pytest.fixture(name="config_dataframe") +def fixture_config_file_dataframe(config) -> DataFrameConfig: + """Fixture providing a configuration file for FlashLoader tests. + + Returns: + dict: The parsed configuration file. + """ + return config.dataframe + + +@pytest.fixture(name="h5_file") +def fixture_h5_file(): + """Fixture providing an open h5 file. + + Returns: + h5py.File: The open h5 file. + """ + return h5py.File(os.path.join(package_dir, f"../tests/data/loader/sxp/{H5_PATH}"), "r") + + +@pytest.fixture(name="h5_file_copy") +def fixture_h5_file_copy(tmp_path): + """Fixture providing a copy of an open h5 file. + + Returns: + h5py.File: The open h5 file copy. + """ + # Create a copy of the h5 file in a temporary directory + original_file_path = os.path.join(package_dir, f"../tests/data/loader/sxp/{H5_PATH}") + copy_file_path = tmp_path / "copy.h5" + shutil.copyfile(original_file_path, copy_file_path) + + # Open the copy in 'read-write' mode and return it + return h5py.File(copy_file_path, "r+") + + +# @pytest.fixture(name="pulserSignAdc_channel_array") +# def get_pulse_channel_from_h5(config_dataframe, h5_file): +# df = DataFrameCreator(config_dataframe) +# df.h5_file = h5_file +# train_id, pulse_id = df.get_dataset_array("pulserSignAdc") +# return train_id, pulse_id + + +# @pytest.fixture(name="multiindex_electron") +# def fixture_multi_index_electron(config_dataframe, h5_file): +# """Fixture providing multi index for electron resolved data""" +# df = DataFrameCreator(config_dataframe) +# df.h5_file = h5_file +# pulse_index, indexer = df.pulse_index(config_dataframe["ubid_offset"]) + +# return pulse_index, indexer + + +# @pytest.fixture(name="fake_data") +# def fake_data_electron(): +# # Creating manageable fake data, but not used currently +# num_trains = 5 +# max_pulse_id = 100 +# nan_threshold = 50 +# ubid_offset = 5 +# seed = 42 +# np.random.seed(seed) +# train_ids = np.arange(1600000000, 1600000000 + num_trains) +# fake_data = [] + +# for _ in train_ids: +# pulse_ids = [] +# while len(pulse_ids) < nan_threshold: +# random_pulse_ids = np.random.choice( +# np.arange(ubid_offset, nan_threshold), size=np.random.randint(0, 10)) +# pulse_ids = np.concatenate([pulse_ids, random_pulse_ids]) + +# pulse_ids = np.concatenate([pulse_ids, np.full(max_pulse_id-len(pulse_ids), np.nan)]) + +# fake_data.append(np.sort(pulse_ids)) +# return Series(train_ids, name="trainId"), np.array(fake_data), ubid_offset diff --git a/tests/loader/sxp/test_dataframe_creator.py b/tests/loader/sxp/test_dataframe_creator.py new file mode 100644 index 00000000..f78b6224 --- /dev/null +++ b/tests/loader/sxp/test_dataframe_creator.py @@ -0,0 +1,173 @@ +"""Tests for SXPDataFrameCreator functionality""" +import numpy as np +import pytest +from pandas import DataFrame +from pandas import Index + +from sed.loader.fel.utils import get_channels +from sed.loader.sxp.dataframe import SXPDataFrameCreator + + +def test_get_dataset_array(config_dataframe, h5_file): + """Test the creation of a h5py dataset for a given channel.""" + + df = SXPDataFrameCreator(config_dataframe, h5_file) + channel = "dldPosX" + max_hits = df._config.channels.get(channel).max_hits + + train_id, dset = df.get_dataset_array(channel) + # Check that the train_id and np_array have the correct shapes and types + assert isinstance(train_id, Index) + assert isinstance(dset, np.ndarray) + assert train_id.name == "trainId" + assert train_id.shape[0] == dset.shape[0] + assert dset.shape[1] == max_hits + + channel 
= "delayStage" + train_id, dset = df.get_dataset_array(channel) + assert train_id.shape[0] == dset.shape[0] + + +def test_empty_get_dataset_array(config_dataframe, h5_file, h5_file_copy): + """Test the method when given an empty dataset.""" + + channel = "delayStage" + df = SXPDataFrameCreator(config_dataframe, h5_file) + train_id, dset = df.get_dataset_array(channel) + + channel_index_key = "/INDEX/trainId" + empty_dataset_key = "/CONTROL/SCS_ILH_LAS/MDL/OPTICALDELAY_PP800/actualPosition/empty" + config_dataframe.channels.get(channel).index_key = channel_index_key + config_dataframe.channels.get(channel).dataset_key = empty_dataset_key + # Remove the 'group_name' key + del config_dataframe.channels.get(channel).group_name + + # create an empty dataset + h5_file_copy.create_dataset( + name=empty_dataset_key, + shape=(train_id.shape[0], 0), + ) + + df = SXPDataFrameCreator(config_dataframe, h5_file) + df.h5_file = h5_file_copy + train_id, dset_empty = df.get_dataset_array(channel) + + assert dset_empty.shape[0] == train_id.shape[0] + + +# def test_pulse_index(config_dataframe, h5_file): +# """Test the creation of the pulse index for electron resolved data""" + +# df = SXPDataFrameCreator(config_dataframe, h5_file) +# pulse_index, pulse_array = df.get_dataset_array("pulseId", slice_=True) +# index, indexer = df.pulse_index(config_dataframe.ubid_offset) +# # Check if the index_per_electron is a MultiIndex and has the correct levels +# assert isinstance(index, MultiIndex) +# assert set(index.names) == {"trainId", "pulseId", "electronId"} + +# # Check if the pulse_index has the correct number of elements +# # This should be the pulses without nan values +# pulse_rav = pulse_array.ravel() +# pulse_no_nan = pulse_rav[~np.isnan(pulse_rav)] +# assert len(index) == len(pulse_no_nan) + +# # Check if all pulseIds are correctly mapped to the index +# assert np.all( +# index.get_level_values("pulseId").values +# == (pulse_no_nan - config_dataframe.ubid_offset)[indexer], +# ) + +# assert np.all( +# index.get_level_values("electronId").values[:5] == [0, 1, 0, 1, 0], +# ) + +# assert np.all( +# index.get_level_values("electronId").values[-5:] == [1, 0, 1, 0, 1], +# ) + +# # check if all indexes are unique and monotonic increasing +# assert index.is_unique +# assert index.is_monotonic_increasing + + +# def test_df_electron(config_dataframe, h5_file): +# """Test the creation of a pandas DataFrame for a channel of type [per electron].""" +# df = SXPDataFrameCreator(config_dataframe, h5_file) + +# result_df = df.df_electron + +# # check index levels +# assert set(result_df.index.names) == {"trainId", "pulseId", "electronId"} + +# # check that there are no nan values in the dataframe +# assert ~result_df.isnull().values.any() + +# # Check that the values are dropped for pulseId index below 0 (ubid_offset) +# print( +# np.all( +# result_df.values[:5] +# != np.array( +# [ +# [556.0, 731.0, 42888.0], +# [549.0, 737.0, 42881.0], +# [671.0, 577.0, 39181.0], +# [671.0, 579.0, 39196.0], +# [714.0, 859.0, 37530.0], +# ], +# ), +# ), +# ) +# assert np.all(result_df.index.get_level_values("pulseId") >= 0) +# assert isinstance(result_df, DataFrame) + +# assert result_df.index.is_unique + +# # check that dataframe contains all subchannels +# assert np.all( +# set(result_df.columns) == set(get_channels(config_dataframe.channels, ["per_electron"])), +# ) + + +def test_create_dataframe_per_train(config_dataframe, h5_file): + """Test the creation of a pandas DataFrame for a channel of type [per train].""" + df = 
SXPDataFrameCreator(config_dataframe, h5_file) + result_df = df.df_train + + channel = "delayStage" + _, data = df.get_dataset_array(channel) + + # Check that the result_df is a DataFrame and has the correct shape + assert isinstance(result_df, DataFrame) + + # check that all values are in the df for delayStage + np.all(result_df[channel].dropna() == data[()]) + + # check that dataframe contains all channels + assert np.all( + set(result_df.columns) + == set(get_channels(config_dataframe.channels, ["per_train"], extend_aux=True)), + ) + + # find unique index values among all per_train channels + channels = get_channels(config_dataframe.channels, ["per_train"]) + all_keys = Index([]) + for channel in channels: + all_keys = all_keys.append(df.get_dataset_array(channel)[0]) + assert result_df.shape[0] == len(all_keys.unique()) + + # check index levels + assert set(result_df.index.names) == {"trainId", "pulseId", "electronId"} + + # all pulseIds and electronIds should be 0 + assert np.all(result_df.index.get_level_values("pulseId") == 0) + assert np.all(result_df.index.get_level_values("electronId") == 0) + + +def test_group_name_not_in_h5(config_dataframe, h5_file): + """Test ValueError when the group_name for a channel does not exist in the H5 file.""" + channel = "delayStage" + config = config_dataframe + config.channels.get(channel).index_key = "foo" + df = SXPDataFrameCreator(config, h5_file) + with pytest.raises(KeyError): + df.df_train diff --git a/tests/loader/sxp/test_sxp_loader.py b/tests/loader/sxp/test_sxp_loader.py index 3332f231..2ea0e82d 100644 --- a/tests/loader/sxp/test_sxp_loader.py +++ b/tests/loader/sxp/test_sxp_loader.py @@ -1,87 +1,18 @@ # pylint: disable=duplicate-code """Tests for SXPLoader functionality""" -import os -from importlib.util import find_spec from pathlib import Path -from typing import List -import pytest - -from sed.core.config import parse_config from sed.loader.sxp.loader import SXPLoader -package_dir = os.path.dirname(find_spec("sed").origin) -config_path = os.path.join(package_dir, "../tests/data/loader/sxp/config.yaml") -H5_PATH = "RAW-R0016-DA03-S00000.h5" - - -@pytest.fixture(name="config_file") -def fixture_config_file() -> dict: - """Fixture providing a configuration file for SXPLoader tests. - - Returns: - dict: The parsed configuration file. - """ - return parse_config(config=config_path, folder_config={}, user_config={}, system_config={}) - - -def test_get_channels_by_format(config_file: dict) -> None: - """ - Test function to verify the 'get_channels' method in SXPLoader class for - retrieving channels based on formats and index inclusion. - """ - # Initialize the SXPLoader instance with the given config_file. - sl = SXPLoader(config_file) - - # Define expected channels for each format. - electron_channels = ["dldPosX", "dldPosY", "dldTimeSteps"] - pulse_channels: List[str] = [] - train_channels = ["timeStamp", "delayStage"] - index_channels = ["trainId", "pulseId", "electronId"] - - # Call get_channels method with different format options. - - # Request channels for 'per_electron' format using a list. - format_electron = sl.get_channels(["per_electron"]) - - # Request channels for 'per_pulse' format using a string. - format_pulse = sl.get_channels("per_pulse") - - # Request channels for 'per_train' format using a list. - format_train = sl.get_channels(["per_train"]) - - # Request channels for 'all' formats using a list. - format_all = sl.get_channels(["all"]) - # Request index channels only. 
- format_index = sl.get_channels(index=True) - - # Request 'per_electron' format and include index channels. - format_index_electron = sl.get_channels(["per_electron"], index=True) - - # Request 'all' formats and include index channels. - format_all_index = sl.get_channels(["all"], index=True) - - # Assert that the obtained channels match the expected channels. - assert set(electron_channels) == set(format_electron) - assert set(pulse_channels) == set(format_pulse) - assert set(train_channels) == set(format_train) - assert set(electron_channels + pulse_channels + train_channels) == set(format_all) - assert set(index_channels) == set(format_index) - assert set(index_channels + electron_channels) == set(format_index_electron) - assert set(index_channels + electron_channels + pulse_channels + train_channels) == set( - format_all_index, - ) - - -def test_initialize_paths(config_file: dict, fs) -> None: +def test_initialize_paths(config_raw: dict, fs) -> None: """ Test the initialization of paths based on the configuration and directory structures. Args: fs: A fixture for a fake file system. """ - config = config_file + config = config_raw del config["core"]["paths"] config["core"]["beamtime_id"] = "12345678" config["core"]["year"] = "2000" @@ -99,116 +30,8 @@ def test_initialize_paths(config_file: dict, fs) -> None: # Instance of class with correct config and call initialize_paths sl = SXPLoader(config=config) - data_raw_dir, data_parquet_dir = sl.initialize_paths() + data_raw_dir = sl.config.core.paths.data_raw_dir + data_parquet_dir = sl.config.core.paths.data_parquet_dir - assert expected_raw_path == data_raw_dir[0] + assert expected_raw_path == data_raw_dir assert expected_processed_path == data_parquet_dir - - -def test_initialize_paths_filenotfound(config_file: dict): - """ - Test FileNotFoundError during the initialization of paths. - """ - # Test the FileNotFoundError - config = config_file - del config["core"]["paths"] - config["core"]["beamtime_id"] = "11111111" - config["core"]["year"] = "2000" - - # Instance of class with correct config and call initialize_paths - sl = SXPLoader(config=config) - with pytest.raises(FileNotFoundError): - _, _ = sl.initialize_paths() - - -def test_invalid_channel_format(config_file: dict): - """ - Test ValueError for an invalid channel format. - """ - config = config_file - config["dataframe"]["channels"]["dldPosX"]["format"] = "foo" - - sl = SXPLoader(config=config) - - with pytest.raises(ValueError): - sl.read_dataframe() - - -@pytest.mark.parametrize( - "key_type", - ["dataset_key", "index_key"], -) -def test_data_keys_not_in_h5(config_file: dict, key_type: str): - """Test ValueError when the dataset_key or index_key for a channel does not exist in the H5 - file. - - Args: - key_type (str): Key type to check - """ - config = config_file - config["dataframe"]["channels"]["dldPosX"][key_type] = "foo" - sl = SXPLoader(config=config) - - with pytest.raises(ValueError) as e: - sl.create_dataframe_per_file(config["core"]["paths"]["data_raw_dir"] + H5_PATH) - - assert str(e.value.args[0]) == f"The {key_type} for channel dldPosX does not exist." - - -def test_buffer_schema_mismatch(config_file: dict): - """ - Test function to verify schema mismatch handling in the SXPLoader's 'read_dataframe' method. - - The test validates the error handling mechanism when the available channels do not match the - schema of the existing parquet files. - - Test Steps: - - Attempt to read a dataframe after adding a new channel 'delayStage2' to the configuration. 
- - Check for an expected error related to the mismatch between available channels and schema. - - Force recreation of dataframe with the added channel, ensuring successful dataframe creation. - - Simulate a missing channel scenario by removing 'delayStage2' from the configuration. - - Check for an error indicating a missing channel in the configuration. - - Clean up created buffer files after the test. - """ - sl = SXPLoader(config=config_file) - - # Read a dataframe for a specific run - sl.read_dataframe(runs=["0016"]) - - # Manipulate the configuration to introduce a new channel 'delayStage2' - config = config_file - config["dataframe"]["channels"]["delayStage2"] = { - "format": "per_train", - "dataset_key": "/CONTROL/SCS_ILH_LAS/MDL/OPTICALDELAY_PP800/actualPosition/value", - "index_key": "/INDEX/trainId", - } - - # Reread the dataframe with the modified configuration, expecting a schema mismatch error - sl = SXPLoader(config=config) - with pytest.raises(ValueError) as e: - sl.read_dataframe(runs=["0016"]) - expected_error = e.value.args - - # Validate the specific error messages for schema mismatch - assert "The available channels do not match the schema of file" in expected_error[0] - assert expected_error[2] == "Missing in parquet: {'delayStage2'}" - assert expected_error[4] == "Please check the configuration file or set force_recreate to True." - - # Force recreation of the dataframe, including the added channel 'delayStage2' - sl.read_dataframe(runs=["0016"], force_recreate=True) - - # Remove 'delayStage2' from the configuration to simulate a missing channel scenario - del config["dataframe"]["channels"]["delayStage2"] - sl = SXPLoader(config=config) - with pytest.raises(ValueError) as e: - # Attempt to read the dataframe again to check for the missing channel error - sl.read_dataframe(runs=["0016"]) - - expected_error = e.value.args - # Check for the specific error message indicating a missing channel in the configuration - assert expected_error[3] == "Missing in config: {'delayStage2'}" - - # Clean up created buffer files after the test - _, parquet_data_dir = sl.initialize_paths() - for file in os.listdir(Path(parquet_data_dir, "buffer")): - os.remove(Path(parquet_data_dir, "buffer", file)) diff --git a/tests/loader/test_loaders.py b/tests/loader/test_loaders.py index f638ba0d..07372959 100644 --- a/tests/loader/test_loaders.py +++ b/tests/loader/test_loaders.py @@ -163,7 +163,7 @@ def test_has_correct_read_dataframe_func(loader: BaseLoader, read_type: str) -> if loader.__name__ in {"flash", "sxp"}: loader = cast(FlashLoader, loader) - _, parquet_data_dir = loader.initialize_paths() + parquet_data_dir = loader.config.core.paths.data_parquet_dir for file in os.listdir(Path(parquet_data_dir, "buffer")): os.remove(Path(parquet_data_dir, "buffer", file)) @@ -196,7 +196,7 @@ def test_timed_dataframe(loader: BaseLoader) -> None: if loaded_timed_dataframe is None: if loader.__name__ in {"flash", "sxp"}: loader = cast(FlashLoader, loader) - _, parquet_data_dir = loader.initialize_paths() + parquet_data_dir = loader.config.core.paths.data_parquet_dir for file in os.listdir(Path(parquet_data_dir, "buffer")): os.remove(Path(parquet_data_dir, "buffer", file)) pytest.skip("Not implemented") @@ -206,7 +206,7 @@ def test_timed_dataframe(loader: BaseLoader) -> None: if loader.__name__ in {"flash", "sxp"}: loader = cast(FlashLoader, loader) - _, parquet_data_dir = loader.initialize_paths() + parquet_data_dir = loader.config.core.paths.data_parquet_dir for file in 
os.listdir(Path(parquet_data_dir, "buffer")): os.remove(Path(parquet_data_dir, "buffer", file)) @@ -240,7 +240,7 @@ def test_get_count_rate(loader: BaseLoader) -> None: if loaded_time is None and loaded_countrate is None: if loader.__name__ in {"flash", "sxp"}: loader = cast(FlashLoader, loader) - _, parquet_data_dir = loader.initialize_paths() + parquet_data_dir = loader.config.core.paths.data_parquet_dir for file in os.listdir(Path(parquet_data_dir, "buffer")): os.remove(Path(parquet_data_dir, "buffer", file)) pytest.skip("Not implemented") @@ -251,7 +251,7 @@ def test_get_count_rate(loader: BaseLoader) -> None: if loader.__name__ in {"flash", "sxp"}: loader = cast(FlashLoader, loader) - _, parquet_data_dir = loader.initialize_paths() + parquet_data_dir = loader.config.core.paths.data_parquet_dir for file in os.listdir(Path(parquet_data_dir, "buffer")): os.remove(Path(parquet_data_dir, "buffer", file)) @@ -285,7 +285,7 @@ def test_get_elapsed_time(loader: BaseLoader) -> None: if elapsed_time is None: if loader.__name__ in {"flash", "sxp"}: loader = cast(FlashLoader, loader) - _, parquet_data_dir = loader.initialize_paths() + parquet_data_dir = loader.config.core.paths.data_parquet_dir for file in os.listdir(Path(parquet_data_dir, "buffer")): os.remove(Path(parquet_data_dir, "buffer", file)) pytest.skip("Not implemented") @@ -296,7 +296,7 @@ def test_get_elapsed_time(loader: BaseLoader) -> None: if loader.__name__ in {"flash", "sxp"}: loader = cast(FlashLoader, loader) - _, parquet_data_dir = loader.initialize_paths() + parquet_data_dir = loader.config.core.paths.data_parquet_dir for file in os.listdir(Path(parquet_data_dir, "buffer")): os.remove(Path(parquet_data_dir, "buffer", file)) diff --git a/tests/test_processor.py b/tests/test_processor.py index 98bdba2d..09efa61c 100644 --- a/tests/test_processor.py +++ b/tests/test_processor.py @@ -629,7 +629,7 @@ def test_align_dld_sectors() -> None: np.testing.assert_allclose(tof_ref_array, tof_aligned_array + sector_delays[:, np.newaxis]) # cleanup flash inermediaries - _, parquet_data_dir = cast(FlashLoader, processor.loader).initialize_paths() + parquet_data_dir = cast(FlashLoader, processor.loader).config.core.paths.data_parquet_dir for file in os.listdir(Path(parquet_data_dir, "buffer")): os.remove(Path(parquet_data_dir, "buffer", file))