Merge pull request #484 from OpenCOMPES/remove-invalid-channels
FlashLoader: Remove invalid files by catching exceptions; filter out invalid pulses
zain-sohail authored Aug 21, 2024
2 parents 779b594 + a890d5d commit 60e43e7
Showing 6 changed files with 175 additions and 49 deletions.
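
Usage-wise, the new behavior is opt-in through the loader entry point. A minimal sketch, assuming a FlashLoader built from a config dict and a placeholder run ID (both illustrative, not taken from this diff):

from sed.loader.flash.loader import FlashLoader

loader = FlashLoader(config=config)  # config: dict, assembled elsewhere
df, df_timed, metadata = loader.read_dataframe(
    runs=[12345],  # placeholder run ID
    collect_metadata=False,
    # New in this PR: skip h5 files with missing channel keys
    # instead of failing the whole conversion.
    remove_invalid_files=True,
)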
53 changes: 46 additions & 7 deletions sed/loader/flash/buffer_handler.py
@@ -9,15 +9,19 @@
from joblib import Parallel

from sed.core.dfops import forward_fill_lazy
from sed.core.logging import setup_logging
from sed.loader.flash.dataframe import DataFrameCreator
from sed.loader.flash.utils import get_channels
from sed.loader.flash.utils import get_dtypes
from sed.loader.flash.utils import InvalidFileError
from sed.loader.utils import get_parquet_metadata
from sed.loader.utils import split_dld_time_from_sector_id


DF_TYP = ["electron", "timed"]

logger = setup_logging(__name__)


class BufferFilePaths:
"""
@@ -33,7 +37,14 @@ class BufferFilePaths:
}
"""

def __init__(self, h5_paths: list[Path], folder: Path, suffix: str) -> None:
def __init__(
self,
config: dict,
h5_paths: list[Path],
folder: Path,
suffix: str,
remove_invalid_files: bool,
) -> None:
"""Initializes the BufferFilePaths.
Args:
@@ -45,8 +56,18 @@ def __init__(self, h5_paths: list[Path], folder: Path, suffix: str) -> None:
folder = folder / "buffer"
folder.mkdir(parents=True, exist_ok=True)

# a list of file sets containing the paths to the raw, electron and timed buffer files
self._file_paths = [
if remove_invalid_files:
h5_paths = self.remove_invalid_files(config, h5_paths)

self._file_paths = self._create_file_paths(h5_paths, folder, suffix)

def _create_file_paths(
self,
h5_paths: list[Path],
folder: Path,
suffix: str,
) -> list[dict[str, Path]]:
return [
{
"raw": h5_path,
**{typ: folder / f"{typ}_{h5_path.stem}{suffix}" for typ in DF_TYP},
@@ -71,6 +92,18 @@ def file_sets_to_process(self, force_recreate: bool = False) -> list[dict[str, P
return self._file_paths
return [file_set for file_set in self if any(not file_set[key].exists() for key in DF_TYP)]

def remove_invalid_files(self, config, h5_paths: list[Path]) -> list[Path]:
valid_h5_paths = []
for h5_path in h5_paths:
try:
dfc = DataFrameCreator(config_dataframe=config, h5_path=h5_path)
dfc.validate_channel_keys()
valid_h5_paths.append(h5_path)
except InvalidFileError as e:
logger.info(f"Skipping invalid file: {h5_path.stem}\n{e}")

return valid_h5_paths


class BufferHandler:
"""
@@ -157,6 +190,7 @@ def _save_buffer_file(self, paths: dict[str, Path]) -> None:
df_timed = df[self.fill_channels].loc[:, :, 0]
dtypes = get_dtypes(self._config, df_timed.columns.values)
df_timed.astype(dtypes).reset_index().to_parquet(paths["timed"])
logger.debug(f"Processed {paths['raw'].stem}")

def _save_buffer_files(self, force_recreate: bool, debug: bool) -> None:
"""
@@ -167,13 +201,12 @@ def _save_buffer_files(self, force_recreate: bool, debug: bool) -> None:
debug (bool): Flag to enable debug mode, which serializes the creation.
"""
file_sets = self.fp.file_sets_to_process(force_recreate)
print(f"Reading files: {len(file_sets)} new files of {len(self.fp)} total.")
logger.info(f"Reading files: {len(file_sets)} new files of {len(self.fp)} total.")
n_cores = min(len(file_sets), self.n_cores)
if n_cores > 0:
if debug:
for file_set in file_sets:
self._save_buffer_file(file_set)
print(f"Processed {file_set['raw'].stem}")
else:
Parallel(n_jobs=n_cores, verbose=10)(
delayed(self._save_buffer_file)(file_set) for file_set in file_sets
@@ -191,6 +224,8 @@ def _get_dataframes(self) -> None:
it pulse resolved (no longer electron resolved). If time_index is True,
the timeIndex is calculated and set as the index (slow operation).
"""
if not self.fp:
raise FileNotFoundError("Buffer files do not exist.")
# Loop over the electron and timed dataframes
file_stats = {}
filling = {}
Expand Down Expand Up @@ -219,7 +254,10 @@ def _get_dataframes(self) -> None:
self.df[typ] = df
self.metadata.update({"file_statistics": file_stats, "filling": filling})
# Correct the 3-bit shift which encodes the detector ID in the 8s time
if self._config.get("split_sector_id_from_dld_time", False):
if (
self._config.get("split_sector_id_from_dld_time", False)
and self._config.get("tof_column", "dldTimeSteps") in self.df["electron"].columns
):
self.df["electron"], meta = split_dld_time_from_sector_id(
self.df["electron"],
config=self._config,
@@ -233,6 +271,7 @@ def process_and_load_dataframe(
force_recreate: bool = False,
suffix: str = "",
debug: bool = False,
remove_invalid_files: bool = False,
) -> tuple[dd.DataFrame, dd.DataFrame]:
"""
Runs the buffer file creation process.
@@ -249,7 +288,7 @@
Returns:
Tuple[dd.DataFrame, dd.DataFrame]: The electron and timed dataframes.
"""
self.fp = BufferFilePaths(h5_paths, folder, suffix)
self.fp = BufferFilePaths(self._config, h5_paths, folder, suffix, remove_invalid_files)

if not force_recreate:
schema_set = set(
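
For reference, the file-set mapping built by _create_file_paths can be reproduced in isolation; the paths below are illustrative only:

from pathlib import Path

DF_TYP = ["electron", "timed"]

h5_path = Path("raw/run_001.h5")   # illustrative raw file
folder = Path("processed/buffer")  # illustrative buffer folder
suffix = ""                        # detector suffix, empty by default

file_set = {
    "raw": h5_path,
    **{typ: folder / f"{typ}_{h5_path.stem}{suffix}" for typ in DF_TYP},
}
# -> pairs the raw file with 'electron_run_001' and 'timed_run_001'
#    targets under processed/buffer

Each such set pairs one raw h5 file with its electron and timed buffer targets, which is what file_sets_to_process checks for existence.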
43 changes: 18 additions & 25 deletions sed/loader/flash/dataframe.py
@@ -13,6 +13,7 @@
import pandas as pd

from sed.loader.flash.utils import get_channels
from sed.loader.flash.utils import InvalidFileError


class DataFrameCreator:
@@ -148,16 +149,16 @@ def df_electron(self) -> pd.DataFrame:
Returns:
pd.DataFrame: The pandas DataFrame for the 'per_electron' channel's data.
"""
offset = self._config.get("ubid_offset", 5) # 5 is the default value
# Here we get the multi-index and the indexer to sort the data
index, indexer = self.pulse_index(offset)

# Get the relevant channels and their slice index
channels = get_channels(self._config, "per_electron")
if channels == []:
return pd.DataFrame()
slice_index = [self._config["channels"][channel].get("slice", None) for channel in channels]

offset = self._config.get("ubid_offset", 5) # 5 is the default value
# Here we get the multi-index and the indexer to sort the data
index, indexer = self.pulse_index(offset)

# First checking if dataset keys are the same for all channels
# because DLD at FLASH stores all channels in the same h5 dataset
dataset_keys = [self.get_index_dataset_key(channel)[1] for channel in channels]
Expand All @@ -180,20 +181,8 @@ def df_electron(self) -> pd.DataFrame:
}
dataframe = pd.concat(series, axis=1)

# after offset, the negative pulse values are dropped as they are not valid
drop_vals = np.arange(-offset, 0)

# Few things happen here:
# Drop all NaN values like while creating the multiindex
# if necessary, the data is sorted with [indexer]
# pd.MultiIndex is set
# Finally, the offset values are dropped
return (
dataframe.dropna()
.iloc[indexer]
.set_index(index)
.drop(index=drop_vals, level="pulseId", errors="ignore")
)
# NaN values dropped, data sorted with [indexer] if necessary, and the MultiIndex is set
return dataframe.dropna().iloc[indexer].set_index(index)

@property
def df_pulse(self) -> pd.DataFrame:
@@ -279,17 +268,19 @@ def df_train(self) -> pd.DataFrame:

def validate_channel_keys(self) -> None:
"""
Validates if the index and dataset keys for all channels in config exist in the h5 file.
Validates if the index and dataset keys for all channels in the config exist in the h5 file.
Raises:
KeyError: If the index or dataset keys do not exist in the file.
InvalidFileError: If the index or dataset keys are missing in the h5 file.
"""
invalid_channels = []
for channel in self._config["channels"]:
index_key, dataset_key = self.get_index_dataset_key(channel)
if index_key not in self.h5_file:
raise KeyError(f"pd.Index key '{index_key}' doesn't exist in the file.")
if dataset_key not in self.h5_file:
raise KeyError(f"Dataset key '{dataset_key}' doesn't exist in the file.")
if index_key not in self.h5_file or dataset_key not in self.h5_file:
invalid_channels.append(channel)

if invalid_channels:
raise InvalidFileError(invalid_channels)

@property
def df(self) -> pd.DataFrame:
@@ -304,4 +295,6 @@ def df(self) -> pd.DataFrame:
self.validate_channel_keys()
# been tested with merge, join and concat
# concat offers best performance, almost 3 times faster
return pd.concat((self.df_electron, self.df_pulse, self.df_train), axis=1).sort_index()
df = pd.concat((self.df_electron, self.df_pulse, self.df_train), axis=1).sort_index()
# all the negative pulse values are dropped as they are invalid
return df[df.index.get_level_values("pulseId") >= 0]
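
The pulse filter added at the end of df replaces the drop(index=drop_vals, level="pulseId") step removed from df_electron. A toy illustration, assuming the loader's usual (trainId, pulseId, electronId) MultiIndex:

import pandas as pd

index = pd.MultiIndex.from_tuples(
    [(100, -2, 0), (100, 0, 0), (100, 1, 0)],
    names=["trainId", "pulseId", "electronId"],
)
df = pd.DataFrame({"dldTimeSteps": [10, 20, 30]}, index=index)

# negative pulse IDs arise from the ubid offset and are invalid, so drop them
df = df[df.index.get_level_values("pulseId") >= 0]
print(df)  # only the rows with pulseId 0 and 1 remain

Filtering once on the concatenated frame applies the rule uniformly across all channel types instead of only the electron data.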
35 changes: 26 additions & 9 deletions sed/loader/flash/loader.py
@@ -203,9 +203,10 @@ def get_elapsed_time(self, fids: Sequence[int] = None, **kwds) -> float | list[f
Args:
fids (Sequence[int]): A sequence of file IDs. Defaults to all files.
**kwds:
- runs: A sequence of run IDs. Takes precedence over fids.
- aggregate: Whether to return the sum of the elapsed times across
Keyword Args:
runs: A sequence of run IDs. Takes precedence over fids.
aggregate: Whether to return the sum of the elapsed times across
the specified files or the elapsed time for each file. Defaults to True.
Returns:
@@ -269,10 +270,6 @@ def read_dataframe(
ftype: str = "h5",
metadata: dict = {},
collect_metadata: bool = False,
detector: str = "",
force_recreate: bool = False,
processed_dir: str | Path = None,
debug: bool = False,
**kwds,
) -> tuple[dd.DataFrame, dd.DataFrame, dict]:
"""
@@ -289,7 +286,17 @@
ftype (str, optional): The file extension type. Defaults to "h5".
metadata (dict, optional): Additional metadata. Defaults to None.
collect_metadata (bool, optional): Whether to collect metadata. Defaults to False.
**kwds: Additional keyword arguments passed to ``parse_metadata``.
Keyword Args:
detector (str, optional): The detector to use. Defaults to "".
force_recreate (bool, optional): Whether to force recreation of the buffer files.
Defaults to False.
processed_dir (str, optional): The directory to save the processed files.
Defaults to None.
debug (bool, optional): Whether to run buffer creation in serial. Defaults to False.
remove_invalid_files (bool, optional): Whether to exclude invalid files.
Defaults to False.
scicat_token (str, optional): The scicat token to use for fetching metadata.
Returns:
tuple[dd.DataFrame, dd.DataFrame, dict]: A tuple containing the concatenated DataFrame
@@ -299,6 +306,15 @@
ValueError: If neither 'runs' nor 'files'/'raw_dir' is provided.
FileNotFoundError: If the conversion fails for some files or no data is available.
"""
detector = kwds.pop("detector", "")
force_recreate = kwds.pop("force_recreate", False)
processed_dir = kwds.pop("processed_dir", None)
debug = kwds.pop("debug", False)
remove_invalid_files = kwds.pop("remove_invalid_files", False)
scicat_token = kwds.pop("scicat_token", None)

if len(kwds) > 0:
raise ValueError(f"Unexpected keyword arguments: {kwds.keys()}")
t0 = time.time()

self._initialize_dirs()
@@ -341,12 +357,13 @@
force_recreate=force_recreate,
suffix=detector,
debug=debug,
remove_invalid_files=remove_invalid_files,
)

if self.instrument == "wespe":
df, df_timed = wespe_convert(df, df_timed)

self.metadata.update(self.parse_metadata(**kwds) if collect_metadata else {})
self.metadata.update(self.parse_metadata(scicat_token) if collect_metadata else {})
self.metadata.update(bh.metadata)

print(f"loading complete in {time.time() - t0: .2f} s")
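
The signature change above moves the loader-specific options into **kwds, popped explicitly so unknown arguments still fail loudly. The pattern in isolation (sketch, hypothetical function name):

def read_dataframe_sketch(**kwds) -> None:
    # pop every supported keyword together with its default ...
    force_recreate = kwds.pop("force_recreate", False)
    remove_invalid_files = kwds.pop("remove_invalid_files", False)
    # ... then anything left over is unsupported
    if len(kwds) > 0:
        raise ValueError(f"Unexpected keyword arguments: {kwds.keys()}")
    print(force_recreate, remove_invalid_files)

read_dataframe_sketch(remove_invalid_files=True)  # fine
# read_dataframe_sketch(typo_arg=1)               # would raise ValueError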
16 changes: 14 additions & 2 deletions sed/loader/flash/utils.py
@@ -62,8 +62,9 @@ def get_channels(

# Get the available channels excluding 'pulseId'.
available_channels = list(channel_dict.keys())
# raises error if not available, but necessary for pulse_index
available_channels.remove(PULSE_ALIAS)
# pulse alias is an index and should not be included in the list of channels.
if PULSE_ALIAS in available_channels:
available_channels.remove(PULSE_ALIAS)

for format_ in formats:
# Gather channels based on the specified format(s).
@@ -108,3 +109,14 @@ def get_dtypes(config_dataframe: dict, df_cols: list) -> dict:
except KeyError:
dtypes[channel] = None
return dtypes


class InvalidFileError(Exception):
"""Raised when an H5 file is invalid due to missing keys defined in the config."""

def __init__(self, invalid_channels: list[str]):
self.invalid_channels = invalid_channels
super().__init__(
f"Channels not in file: {', '.join(invalid_channels)}. "
"If you are using the loader, set 'remove_invalid_files' to True to ignore these files",
)
16 changes: 16 additions & 0 deletions tests/loader/flash/conftest.py
@@ -64,6 +64,22 @@ def fixture_h5_file_copy(tmp_path: Path) -> h5py.File:
return h5py.File(copy_file_path, "r+")


@pytest.fixture(name="h5_file2_copy")
def fixture_h5_file2_copy(tmp_path: Path) -> h5py.File:
"""Fixture providing a copy of an open h5 file.
Returns:
h5py.File: The open h5 file copy.
"""
# Create a copy of the h5 file in a temporary directory
original_file_path = os.path.join(package_dir, f"../tests/data/loader/flash/{H5_PATHS[1]}")
copy_file_path = tmp_path / "copy2.h5"
shutil.copyfile(original_file_path, copy_file_path)

# Open the copy in 'read-write' mode and return it
return h5py.File(copy_file_path, "r+")


@pytest.fixture(name="h5_paths")
def fixture_h5_paths() -> list[Path]:
"""Fixture providing a list of h5 file paths.
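
The new fixture mirrors fixture_h5_file_copy but copies the second test file, presumably so a test can corrupt one file while another stays valid. The generic shape of these fixtures, as a sketch with placeholder names and paths:

import shutil

import h5py
import pytest

@pytest.fixture(name="h5_writable_copy")  # hypothetical name
def fixture_h5_writable_copy(tmp_path) -> h5py.File:
    """Copy a data file into pytest's tmp_path and return a writable handle,
    so a test can mutate it (e.g. delete channel datasets) without touching
    the original."""
    src = "tests/data/loader/flash/example.h5"  # placeholder path
    dst = tmp_path / "copy.h5"
    shutil.copyfile(src, dst)
    return h5py.File(dst, "r+")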
