From 99e10304fe99b8c96253579ea60c81527ea69b4b Mon Sep 17 00:00:00 2001
From: "M. Zain Sohail"
Date: Tue, 10 Oct 2023 16:36:20 +0200
Subject: [PATCH] load metadata directly from parquet files for speedup

---
 sed/loader/flash/loader.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/sed/loader/flash/loader.py b/sed/loader/flash/loader.py
index 3a2a785c..c1622c87 100644
--- a/sed/loader/flash/loader.py
+++ b/sed/loader/flash/loader.py
@@ -7,17 +7,18 @@
 This can then be saved as a parquet for out-of-sed processing and reread back to access other
 sed funtionality.
 """
+import time
 from functools import reduce
 from pathlib import Path
 from typing import List
 from typing import Sequence
 from typing import Tuple
 from typing import Union
-import time
 
 import dask.dataframe as dd
 import h5py
 import numpy as np
+import pyarrow as pa
 from joblib import delayed
 from joblib import Parallel
 from natsort import natsorted
@@ -25,10 +26,10 @@
 from pandas import MultiIndex
 from pandas import Series
 
+from sed.core import dfops
 from sed.loader.base.loader import BaseLoader
 from sed.loader.flash.metadata import MetadataRetriever
 from sed.loader.utils import parse_h5_keys
-from sed.core import dfops
 
 
 class FlashLoader(BaseLoader):
@@ -735,14 +736,16 @@ def parquet_handler(
         # Read all parquet files into one dataframe using dask dataframe
         dataframe = dd.read_parquet(parquet_filenames, calculate_divisions=True)
         # Channels to fill NaN values
-        print('Filling nan values...')
+        print("Filling nan values...")
         channels: List[str] = self.get_channels_by_format(["per_pulse", "per_train"])
+
+        overlap = min(pa.parquet.read_metadata(prq).num_rows for prq in parquet_filenames)
+
         dataframe = dfops.forward_fill_lazy(
             df=dataframe,
             channels=channels,
-            before='max',
-            compute_lengths=True,
-            iterations=self._config['dataframe'].get('forward_fill_iterations', 2),
+            before=overlap,
+            iterations=self._config["dataframe"].get("forward_fill_iterations", 2),
         )
         # Remove the NaNs from per_electron channels
         dataframe = dataframe.dropna(
@@ -845,7 +848,7 @@ def read_dataframe(
         dataframe = self.parquet_handler(data_parquet_dir, **kwds)
 
         metadata = self.parse_metadata() if collect_metadata else {}
 
-        print(f'loading complete in {time.time() - t0:.2f} s')
+        print(f"loading complete in {time.time() - t0:.2f} s")
 
         return dataframe, metadata
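
Note (not part of the patch): the speedup comes from reading only the Parquet footers to learn each file's row count, instead of having forward_fill_lazy derive partition lengths through Dask (the removed before='max' / compute_lengths=True path). Below is a minimal sketch of that footer read under stated assumptions: the file names are hypothetical, and the sketch imports pyarrow.parquet explicitly rather than relying on the submodule being reachable as an attribute of the top-level pyarrow namespace.

    # Sketch only: per-file row counts read from Parquet footers, no column data decoded.
    import pyarrow.parquet as pq

    parquet_filenames = ["part_0.parquet", "part_1.parquet"]  # hypothetical paths

    # FileMetaData.num_rows comes from the file footer, so this stays cheap
    # even for large files.
    row_counts = [pq.read_metadata(name).num_rows for name in parquet_filenames]

    # The smallest file bounds how many rows a forward fill may need from the
    # preceding partition; this is the value the patch passes as before=overlap.
    overlap = min(row_counts)
    print(row_counts, overlap)

Reading footers this way scales with the number of files rather than the data volume, which appears to be the intent behind the subject line.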