Commit
Merge pull request #164 from OpenCOMPES/flash-lazy
Flash lazy
steinnymir authored Oct 9, 2023
2 parents bdcbacc + 24b81c3 commit c05f829
Showing 1 changed file with 21 additions and 57 deletions.
78 changes: 21 additions & 57 deletions sed/loader/flash/loader.py
@@ -8,14 +8,14 @@
 sed funtionality.
 """
 from functools import reduce
-from itertools import compress
 from pathlib import Path
 from typing import List
 from typing import Sequence
 from typing import Tuple
 from typing import Union

 import dask.dataframe as dd
+from dask.diagnostics import ProgressBar
 import h5py
 import numpy as np
 from joblib import delayed
@@ -679,58 +679,6 @@ def buffer_file_handler(self, data_parquet_dir: Path, detector: str):

         return h5_filenames, parquet_filenames

-    def fill_na(
-        self,
-        dataframes: List[dd.DataFrame],
-    ) -> dd.DataFrame:
-        """
-        Fill NaN values in the given dataframes using intrafile forward filling.
-        Args:
-            dataframes (List[dd.DataFrame]): List of dataframes to fill NaN values.
-        Returns:
-            dd.DataFrame: Concatenated dataframe with filled NaN values.
-        Notes:
-            This method is specific to the flash data structure and is used to fill NaN values in
-            certain channels that only store information at a lower frequency. The low frequency
-            channels are exploded to match the dimensions of higher frequency channels, but they
-            may contain NaNs in the other columns. This method fills the NaNs for the specific
-            channels (per_pulse and per_train).
-        """
-        # Channels to fill NaN values
-        channels: List[str] = self.get_channels_by_format(["per_pulse", "per_train"])
-
-        # Fill NaN values within each dataframe
-        for i, _ in enumerate(dataframes):
-            dataframes[i][channels] = dataframes[i][channels].fillna(
-                method="ffill",
-            )
-
-        # Forward fill between consecutive dataframes
-        for i in tqdm(range(1, len(dataframes)), desc='Filling NaNs', leave=True, total=len(dataframes)-1, disable=len(dataframes)==1):
-            # Select pulse channels from current dataframe
-            subset = dataframes[i][channels]
-            # Find columns with NaN values in the first row
-            is_null = subset.loc[0].isnull().values.compute()
-            # Execute if there are NaN values in the first row
-            if is_null.sum() > 0:
-                # Select channel names with only NaNs
-                channels_to_overwrite = list(compress(channels, is_null[0]))
-                # Get values for those channels from the previous dataframe
-                values = dataframes[i - 1][channels].tail(1).values[0]
-                # Create a dictionary to fill NaN values
-                fill_dict = dict(zip(channels, values))
-                fill_dict = {k: v for k, v in fill_dict.items() if k in channels_to_overwrite}
-                # Fill NaN values with the corresponding values from the
-                # previous dataframe
-                dataframes[i][channels_to_overwrite] = subset[channels_to_overwrite].fillna(
-                    fill_dict,
-                )
-        # Concatenate the filled dataframes
-        return dd.concat(dataframes)

     def parquet_handler(
         self,
@@ -788,11 +736,27 @@ def parquet_handler(
             detector,
         )

-        # Read all parquet files using dask and concatenate into one dataframe after filling
-        dataframe = self.fill_na(
-            [dd.read_parquet(file) for file in parquet_filenames],
-        )
+        # Read all parquet files into one dataframe using dask
+        dataframe = dd.read_parquet(parquet_filenames)
+
+        # Channels to fill NaN values
+        channels: List[str] = self.get_channels_by_format(["per_pulse", "per_train"])
+
+        # Define a custom function to forward fill specified columns
+        def forward_fill_partition(df):
+            df[channels] = df[channels].ffill()
+            return df
+
+        # calculate the number of rows in each partition
+        with ProgressBar():
+            print("Computing dataframe shape...")
+            nrows = dataframe.map_partitions(len).compute()
+            max_part_size = min(nrows)
+
+        # Use map_overlap to apply forward_fill_partition
+        dataframe = dataframe.map_overlap(forward_fill_partition, before=max_part_size+1, after=0)
+
+        # Remove the NaNs from per_electron channels
+        dataframe = dataframe.dropna(
+            subset=self.get_channels_by_format(["per_electron"]),
+        )
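
The change above swaps the eager fill_na pass, which loaded every parquet file into its own dataframe and stitched the forward fill across file boundaries by hand, for a lazy forward fill applied per partition with dask's map_overlap: each partition is handed a window of rows from its predecessor, so values of the slow per_pulse and per_train channels can propagate across partition boundaries without computing anything up front. Below is a minimal, self-contained sketch of that pattern, separate from the sed loader itself; the column names (energy, delayStage), the toy data, the partition count, and the choice of exactly one full partition of overlap are illustrative assumptions, not the loader's actual channels or sizes.

import numpy as np
import pandas as pd
import dask.dataframe as dd
from dask.diagnostics import ProgressBar

# Toy frame: "energy" stands in for a fast, always-present channel;
# "delayStage" stands in for a slow channel that is NaN except where it updates.
pdf = pd.DataFrame(
    {
        "energy": np.arange(8, dtype=float),
        "delayStage": [1.0, np.nan, np.nan, np.nan, np.nan, np.nan, 2.0, np.nan],
    },
)
ddf = dd.from_pandas(pdf, npartitions=2)  # two equal partitions of four rows

fill_channels = ["delayStage"]  # placeholder for the per_pulse / per_train channels

def forward_fill_partition(df: pd.DataFrame) -> pd.DataFrame:
    # Runs on one pandas partition; the rows prepended by map_overlap let the
    # fill continue across the partition boundary without a global shuffle.
    df = df.copy()  # avoid mutating the input partition
    df[fill_channels] = df[fill_channels].ffill()
    return df

# Overlap by one full partition so the last defined value of the slow channel
# is visible to the next partition even after a long run of NaNs.
overlap_rows = int(ddf.map_partitions(len).compute().min())
filled = ddf.map_overlap(forward_fill_partition, before=overlap_rows, after=0)

with ProgressBar():
    print(filled.compute())

The full-partition overlap matters because the slow channels can be NaN for many consecutive rows, so overlapping by a single row would not guarantee that a defined value reaches the next partition; the dropna on the per_electron channels that the new parquet_handler applies afterwards then removes the rows that carry no electron data.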

0 comments on commit c05f829
