From bfd3daaee0278ac1f2a9a89bd082aae5b3c4df77 Mon Sep 17 00:00:00 2001 From: Steinn Ymir Agustsson Date: Fri, 13 Oct 2023 13:27:02 +0200 Subject: [PATCH] move more functions --- sed/calibrator/dld.py | 3 +- sed/calibrator/energy.py | 92 +++++++++++++++++++++++++++++++++++++- sed/core/dfops.py | 87 ----------------------------------- sed/loader/flash/loader.py | 4 +- 4 files changed, 94 insertions(+), 92 deletions(-) diff --git a/sed/calibrator/dld.py b/sed/calibrator/dld.py index d1582778..8a24b4e6 100644 --- a/sed/calibrator/dld.py +++ b/sed/calibrator/dld.py @@ -12,7 +12,6 @@ import pandas as pd -# TODO: this could be generalized and moved to dfops for splitting a channel bitwise def unravel_8s_detector_time_channel( df: dask.dataframe.DataFrame, tof_column: str = None, @@ -21,6 +20,8 @@ def unravel_8s_detector_time_channel( ) -> dask.dataframe.DataFrame: """Converts the 8s time in steps to time in steps and sectorID. + # TODO: this could be generalized and moved to dfops for splitting a channel bitwise + The 8s detector encodes the dldSectorID in the 3 least significant bits of the dldTimeSteps channel. diff --git a/sed/calibrator/energy.py b/sed/calibrator/energy.py index 0c08c5d4..f037fd5d 100644 --- a/sed/calibrator/energy.py +++ b/sed/calibrator/energy.py @@ -25,6 +25,7 @@ import xarray as xr from bokeh.io import output_notebook from bokeh.palettes import Category10 as ColorCycle +from dask.diagnostics import ProgressBar from fastdtw import fastdtw from IPython.display import display from lmfit import Minimizer @@ -35,6 +36,7 @@ from scipy.sparse.linalg import lsqr from sed.binning import bin_dataframe +from sed.core import dfops from sed.loader.base.loader import BaseLoader @@ -2127,12 +2129,12 @@ def tof_step_to_ns( raise ValueError("Either tof_ns_column or config must be given.") tof_ns_column = config["dataframe"]["tof_ns_column"] - df[tof_ns_column] = df.map_partitions(step2ns, meta=(tof_column, np.float64)) + df[tof_ns_column] = df.map_partitions(tof2ns, meta=(tof_column, np.float64)) metadata: Dict[str, Any] = {"applied": True, "tof_binwidth": tof_binwidth} return df, metadata -def step2ns( +def tof2ns( df: Union[pd.DataFrame, dask.dataframe.DataFrame], tof_column: str, tof_binwidth: float, @@ -2154,3 +2156,89 @@ def step2ns( """ val = df[tof_column].astype(dtype) * tof_binwidth * 2**tof_binning return val.astype(dtype) + + +def apply_energy_shift( + df: Union[pd.DataFrame, dask.dataframe.DataFrame], + columns: Union[str, Sequence[str]], + signs: Union[int, Sequence[int]], + energy_column: str = None, + mode: Union[str, Sequence[str]] = "direct", + window: float = None, + sigma: float = 2, + rolling_group_channel: str = None, + config: dict = None, +) -> Union[pd.DataFrame, dask.dataframe.DataFrame]: + """Apply an energy shift to the given column(s). + + Args: + df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to use. + energy_column (str): Name of the column containing the energy values. + column_name (Union[str,Sequence[str]]): Name of the column(s) to apply the shift to. + sign (Union[int,Sequence[int]]): Sign of the shift to apply. (+1 or -1) + mode (str): The mode of the shift. One of 'direct', 'average' or rolled. + if rolled, window and sigma must be given. + config (dict): Configuration dictionary. + **kwargs: Additional arguments for the rolling average function. + """ + if energy_column is None: + if config is None: + raise ValueError("Either energy_column or config must be given.") + energy_column = config["dataframe"]["energy_column"] + if isinstance(columns, str): + columns = [columns] + if isinstance(signs, int): + signs = [signs] + if isinstance(mode, str): + mode = [mode] * len(columns) + if len(columns) != len(signs): + raise ValueError("column_name and sign must have the same length.") + with ProgressBar( + minimum=5, + ): + if mode == "rolled": + if window is None: + if config is None: + raise ValueError("Either window or config must be given.") + window = config["dataframe"]["rolling_window"] + if sigma is None: + if config is None: + raise ValueError("Either sigma or config must be given.") + sigma = config["dataframe"]["rolling_sigma"] + if rolling_group_channel is None: + if config is None: + raise ValueError("Either rolling_group_channel or config must be given.") + rolling_group_channel = config["dataframe"]["rolling_group_channel"] + print("rolling averages...") + df = dfops.rolling_average_on_acquisition_time( + df, + rolling_group_channel=rolling_group_channel, + columns=columns, + window=window, + sigma=sigma, + ) + for col, s, m in zip(columns, signs, mode): + s = s / np.abs(s) # enusre s is either +1 or -1 + if m == "rolled": + col = col + "_rolled" + if m == "direct" or m == "rolled": + df[col] = df.map_partitions( + lambda x: x[col] + s * x[energy_column], + meta=(col, np.float32), + ) + elif m == "mean": + print("computing means...") + col_mean = df[col].mean() + df[col] = df.map_partitions( + lambda x: x[col] + s * (x[energy_column] - col_mean), + meta=(col, np.float32), + ) + else: + raise ValueError(f"mode must be one of 'direct', 'mean' or 'rolled'. Got {m}.") + metadata: dict[str, Any] = { + "applied": True, + "energy_column": energy_column, + "column_name": columns, + "sign": signs, + } + return df, metadata diff --git a/sed/core/dfops.py b/sed/core/dfops.py index 4f8fda62..227bfb16 100644 --- a/sed/core/dfops.py +++ b/sed/core/dfops.py @@ -3,7 +3,6 @@ """ # Note: some of the functions presented here were # inspired by https://github.com/mpes-kit/mpes -from typing import Any from typing import Callable from typing import Sequence from typing import Union @@ -245,89 +244,3 @@ def rolling_average_on_acquisition_time( if c + "_rolled" in df.columns: df = df.drop(c + "_rolled", axis=1) return df.merge(df_, left_on="timeStamp", right_on="ts").drop(["ts", "dt"], axis=1) - - -def apply_energy_shift( - df: Union[pd.DataFrame, dask.dataframe.DataFrame], - columns: Union[str, Sequence[str]], - signs: Union[int, Sequence[int]], - energy_column: str = None, - mode: Union[str, Sequence[str]] = "direct", - window: float = None, - sigma: float = 2, - rolling_group_channel: str = None, - config: dict = None, -) -> Union[pd.DataFrame, dask.dataframe.DataFrame]: - """Apply an energy shift to the given column(s). - - Args: - df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to use. - energy_column (str): Name of the column containing the energy values. - column_name (Union[str,Sequence[str]]): Name of the column(s) to apply the shift to. - sign (Union[int,Sequence[int]]): Sign of the shift to apply. (+1 or -1) - mode (str): The mode of the shift. One of 'direct', 'average' or rolled. - if rolled, window and sigma must be given. - config (dict): Configuration dictionary. - **kwargs: Additional arguments for the rolling average function. - """ - if energy_column is None: - if config is None: - raise ValueError("Either energy_column or config must be given.") - energy_column = config["dataframe"]["energy_column"] - if isinstance(columns, str): - columns = [columns] - if isinstance(signs, int): - signs = [signs] - if isinstance(mode, str): - mode = [mode] * len(columns) - if len(columns) != len(signs): - raise ValueError("column_name and sign must have the same length.") - with ProgressBar( - minimum=5, - ): - if mode == "rolled": - if window is None: - if config is None: - raise ValueError("Either window or config must be given.") - window = config["dataframe"]["rolling_window"] - if sigma is None: - if config is None: - raise ValueError("Either sigma or config must be given.") - sigma = config["dataframe"]["rolling_sigma"] - if rolling_group_channel is None: - if config is None: - raise ValueError("Either rolling_group_channel or config must be given.") - rolling_group_channel = config["dataframe"]["rolling_group_channel"] - print("rolling averages...") - df = rolling_average_on_acquisition_time( - df, - rolling_group_channel=rolling_group_channel, - columns=columns, - window=window, - sigma=sigma, - ) - for col, s, m in zip(columns, signs, mode): - s = s / np.abs(s) # enusre s is either +1 or -1 - if m == "rolled": - col = col + "_rolled" - if m == "direct" or m == "rolled": - df[col] = df.map_partitions( - lambda x: x[col] + s * x[energy_column], - meta=(col, np.float32), - ) - elif m == "mean": - print("computing means...") - col_mean = df[col].mean() - df[col] = df.map_partitions( - lambda x: x[col] + s * (x[energy_column] - col_mean), - meta=(col, np.float32), - ) - else: - raise ValueError(f"mode must be one of 'direct', 'mean' or 'rolled'. Got {m}.") - metadata: dict[str, Any] = { - "applied": True, - "energy_column": energy_column, - "column_name": columns, - "sign": signs, - } - return df, metadata diff --git a/sed/loader/flash/loader.py b/sed/loader/flash/loader.py index 807201a2..10baa058 100644 --- a/sed/loader/flash/loader.py +++ b/sed/loader/flash/loader.py @@ -26,7 +26,7 @@ from pandas import MultiIndex from pandas import Series -from sed.calibrator.dld import unravel_8s_detector_time_channel +from sed.calibrator import dld from sed.core import dfops from sed.loader.base.loader import BaseLoader from sed.loader.flash.metadata import MetadataRetriever @@ -596,7 +596,7 @@ def create_dataframe_per_file( df = df.dropna(subset=self._config["dataframe"].get("tof_column", "dldTimeSteps")) # correct the 3 bit shift which encodes the detector ID in the 8s time if self._config["dataframe"].get("unravel_8s_detector_time_channel", False): - df = unravel_8s_detector_time_channel(df, config=self._config) + df = dld.unravel_8s_detector_time_channel(df, config=self._config) return df def create_buffer_file(self, h5_path: Path, parquet_path: Path) -> Union[bool, Exception]: