From 646492601fcf63fbfad598b5115af92789eb9a66 Mon Sep 17 00:00:00 2001 From: "M. Zain Sohail" Date: Sat, 7 Oct 2023 16:10:07 +0200 Subject: [PATCH 1/3] simplify NaN filling in FlashLoader and make it lazy --- sed/loader/flash/loader.py | 68 +++++--------------------------------- 1 file changed, 9 insertions(+), 59 deletions(-) diff --git a/sed/loader/flash/loader.py b/sed/loader/flash/loader.py index f5f5c803..59c68042 100644 --- a/sed/loader/flash/loader.py +++ b/sed/loader/flash/loader.py @@ -8,7 +8,6 @@ sed funtionality. """ from functools import reduce -from itertools import compress from pathlib import Path from typing import List from typing import Sequence @@ -676,60 +675,6 @@ def buffer_file_handler(self, data_parquet_dir: Path, detector: str): return h5_filenames, parquet_filenames - def fill_na( - self, - dataframes: List[dd.DataFrame], - ) -> dd.DataFrame: - """ - Fill NaN values in the given dataframes using intrafile forward filling. - - Args: - dataframes (List[dd.DataFrame]): List of dataframes to fill NaN values. - - Returns: - dd.DataFrame: Concatenated dataframe with filled NaN values. - - Notes: - This method is specific to the flash data structure and is used to fill NaN values in - certain channels that only store information at a lower frequency. The low frequency - channels are exploded to match the dimensions of higher frequency channels, but they - may contain NaNs in the other columns. This method fills the NaNs for the specific - channels (per_pulse and per_train). 
- - """ - # Channels to fill NaN values - channels: List[str] = self.get_channels_by_format(["per_pulse", "per_train"]) - - # Fill NaN values within each dataframe - for i, _ in enumerate(dataframes): - dataframes[i][channels] = dataframes[i][channels].fillna( - method="ffill", - ) - - # Forward fill between consecutive dataframes - for i in range(1, len(dataframes)): - # Select pulse channels from current dataframe - subset = dataframes[i][channels] - # Find columns with NaN values in the first row - is_null = subset.loc[0].isnull().values.compute() - # Execute if there are NaN values in the first row - if is_null.sum() > 0: - # Select channel names with only NaNs - channels_to_overwrite = list(compress(channels, is_null[0])) - # Get values for those channels from the previous dataframe - values = dataframes[i - 1][channels].tail(1).values[0] - # Create a dictionary to fill NaN values - fill_dict = dict(zip(channels, values)) - fill_dict = {k: v for k, v in fill_dict.items() if k in channels_to_overwrite} - # Fill NaN values with the corresponding values from the - # previous dataframe - dataframes[i][channels_to_overwrite] = subset[channels_to_overwrite].fillna( - fill_dict, - ) - - # Concatenate the filled dataframes - return dd.concat(dataframes) - def parquet_handler( self, data_parquet_dir: Path, @@ -786,11 +731,16 @@ def parquet_handler( detector, ) - # Read all parquet files using dask and concatenate into one dataframe after filling - dataframe = self.fill_na( - [dd.read_parquet(file) for file in parquet_filenames], - ) + # Read all parquet files into one dataframe using dask + dataframe = dd.read_parquet(parquet_filenames) + + # Channels to fill NaN values + channels: List[str] = self.get_channels_by_format(["per_pulse", "per_train"]) + + # Fill NaN values + dataframe[channels] = dataframe[channels].ffill() + # Remove the NaNs from per_electron channels dataframe = dataframe.dropna( subset=self.get_channels_by_format(["per_electron"]), ) From 
74df4e5ae4efab380dd984b0d2ae588ad3cfd80a Mon Sep 17 00:00:00 2001 From: "M. Zain Sohail" Date: Sat, 7 Oct 2023 18:07:41 +0200 Subject: [PATCH 2/3] add the map_overlap method to do intrafile filling --- sed/loader/flash/loader.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sed/loader/flash/loader.py b/sed/loader/flash/loader.py index 59c68042..508f8333 100644 --- a/sed/loader/flash/loader.py +++ b/sed/loader/flash/loader.py @@ -737,8 +737,13 @@ def parquet_handler( # Channels to fill NaN values channels: List[str] = self.get_channels_by_format(["per_pulse", "per_train"]) - # Fill NaN values - dataframe[channels] = dataframe[channels].ffill() + # Define a custom function to forward fill specified columns + def forward_fill_partition(df): + df[channels] = df[channels].ffill() + return df + + # Use map_overlap to apply forward_fill_partition + dataframe = dataframe.map_overlap(forward_fill_partition, before=0, after=1) # Remove the NaNs from per_electron channels dataframe = dataframe.dropna( From 5b69796569ee7fb19bc7bcc407aff8c34fe5398b Mon Sep 17 00:00:00 2001 From: Steinn Ymir Agustsson Date: Mon, 9 Oct 2023 14:08:02 +0200 Subject: [PATCH 3/3] corrected map_overlap length to minimum part. 
size --- sed/loader/flash/loader.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sed/loader/flash/loader.py b/sed/loader/flash/loader.py index 508f8333..7be08abe 100644 --- a/sed/loader/flash/loader.py +++ b/sed/loader/flash/loader.py @@ -15,6 +15,7 @@ from typing import Union import dask.dataframe as dd +from dask.diagnostics import ProgressBar import h5py import numpy as np from joblib import delayed @@ -742,8 +743,14 @@ def forward_fill_partition(df): df[channels] = df[channels].ffill() return df + # calculate the number of rows in each partition + with ProgressBar(): + print("Computing dataframe shape...") + nrows = dataframe.map_partitions(len).compute() + max_part_size = min(nrows) + # Use map_overlap to apply forward_fill_partition - dataframe = dataframe.map_overlap(forward_fill_partition, before=0, after=1) + dataframe = dataframe.map_overlap(forward_fill_partition, before=max_part_size+1, after=0) # Remove the NaNs from per_electron channels dataframe = dataframe.dropna(