Skip to content

Commit

Permalink
simplify NaN filling in FlashLoader and make it lazy
Browse files Browse the repository at this point in the history
  • Loading branch information
zain-sohail committed Oct 7, 2023
1 parent 8f49f40 commit 6464926
Showing 1 changed file with 9 additions and 59 deletions.
68 changes: 9 additions & 59 deletions sed/loader/flash/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
sed funtionality.
"""
from functools import reduce
from itertools import compress
from pathlib import Path
from typing import List
from typing import Sequence
Expand Down Expand Up @@ -676,60 +675,6 @@ def buffer_file_handler(self, data_parquet_dir: Path, detector: str):

return h5_filenames, parquet_filenames

def fill_na(
self,
dataframes: List[dd.DataFrame],
) -> dd.DataFrame:
"""
Fill NaN values in the given dataframes using intrafile forward filling.
Args:
dataframes (List[dd.DataFrame]): List of dataframes to fill NaN values.
Returns:
dd.DataFrame: Concatenated dataframe with filled NaN values.
Notes:
This method is specific to the flash data structure and is used to fill NaN values in
certain channels that only store information at a lower frequency. The low frequency
channels are exploded to match the dimensions of higher frequency channels, but they
may contain NaNs in the other columns. This method fills the NaNs for the specific
channels (per_pulse and per_train).
"""
# Channels to fill NaN values
channels: List[str] = self.get_channels_by_format(["per_pulse", "per_train"])

# Fill NaN values within each dataframe
for i, _ in enumerate(dataframes):
dataframes[i][channels] = dataframes[i][channels].fillna(
method="ffill",
)

# Forward fill between consecutive dataframes
for i in range(1, len(dataframes)):
# Select pulse channels from current dataframe
subset = dataframes[i][channels]
# Find columns with NaN values in the first row
is_null = subset.loc[0].isnull().values.compute()
# Execute if there are NaN values in the first row
if is_null.sum() > 0:
# Select channel names with only NaNs
channels_to_overwrite = list(compress(channels, is_null[0]))
# Get values for those channels from the previous dataframe
values = dataframes[i - 1][channels].tail(1).values[0]
# Create a dictionary to fill NaN values
fill_dict = dict(zip(channels, values))
fill_dict = {k: v for k, v in fill_dict.items() if k in channels_to_overwrite}
# Fill NaN values with the corresponding values from the
# previous dataframe
dataframes[i][channels_to_overwrite] = subset[channels_to_overwrite].fillna(
fill_dict,
)

# Concatenate the filled dataframes
return dd.concat(dataframes)

def parquet_handler(
self,
data_parquet_dir: Path,
Expand Down Expand Up @@ -786,11 +731,16 @@ def parquet_handler(
detector,
)

# Read all parquet files using dask and concatenate into one dataframe after filling
dataframe = self.fill_na(
[dd.read_parquet(file) for file in parquet_filenames],
)
# Read all parquet files into one dataframe using dask
dataframe = dd.read_parquet(parquet_filenames)

# Channels to fill NaN values
channels: List[str] = self.get_channels_by_format(["per_pulse", "per_train"])

# Fill NaN values
dataframe[channels] = dataframe[channels].ffill()

# Remove the NaNs from per_electron channels
dataframe = dataframe.dropna(
subset=self.get_channels_by_format(["per_electron"]),
)
Expand Down

0 comments on commit 6464926

Please sign in to comment.