diff --git a/sed/core/dfops.py b/sed/core/dfops.py index ecef954f..7cd8be1e 100644 --- a/sed/core/dfops.py +++ b/sed/core/dfops.py @@ -3,6 +3,8 @@ """ # Note: some of the functions presented here were # inspired by https://github.com/mpes-kit/mpes +from typing import Any +from typing import Dict from typing import Callable from typing import Sequence from typing import Union @@ -193,3 +195,132 @@ def forward_fill_partition(df): after=0, ) return df + + +def rolling_average_on_acquisition_time( + df: Union[pd.DataFrame, dask.dataframe.DataFrame], + rolling_group_channel: str, + columns: str = None, + window: float = None, + sigma: float = 2, + config: dict = None, + +) -> Union[pd.DataFrame, dask.dataframe.DataFrame]: + """ Perform a rolling average with a gaussian weighted window. + + In order to preserve the number of points, the first and last "widnow" + number of points are substituted with the original signal. + # TODO: this is currently very slow, and could do with a remake. + + Args: + df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to use. + group_channel: (str): Name of the column on which to group the data + cols (str): Name of the column on which to perform the rolling average + window (float): Size of the rolling average window + sigma (float): number of standard deviations for the gaussian weighting of the window. + a value of 2 corresponds to a gaussian with sigma equal to half the window size. + Smaller values reduce the weighting in the window frame. + + Returns: + Union[pd.DataFrame, dask.dataframe.DataFrame]: Dataframe with the new columns. + """ + if group_channel is None: + if config is None: + raise ValueError("Either group_channel or config must be given.") + group_channel = config["dataframe"]["rolling_group_channel"] + with ProgressBar(): + print(f'rolling average over {columns}...') + if isinstance(columns,str): + columns=[columns] + df_ = df.groupby(rolling_group_channel).agg({c:'mean' for c in columns}).compute() + df_['dt'] = pd.to_datetime(df_.index, unit='s') + df_['ts'] = df_.index + for c in columns: + df_[c+'_rolled'] = df_[c].interpolate(method='nearest').rolling(window,center=True,win_type='gaussian').mean(std=window/sigma).fillna(df_[c]) + df_ = df_.drop(c, axis=1) + if c+'_rolled' in df.columns: + df = df.drop(c+'_rolled',axis=1) + return df.merge(df_,left_on='timeStamp',right_on='ts').drop(['ts','dt'], axis=1) + + +def apply_energy_shift( + df: Union[pd.DataFrame, dask.dataframe.DataFrame], + columns: Union[str,Sequence[str]], + signs: Union[int,Sequence[int]], + energy_column: str = None, + mode: Union[str,Sequence[str]] = "direct", + window: float = None, + sigma: float = 2, + rolling_group_channel: str = None, + config: dict = None, +) -> Union[pd.DataFrame, dask.dataframe.DataFrame]: + """ Apply an energy shift to the given column(s). + + Args: + df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to use. + energy_column (str): Name of the column containing the energy values. + column_name (Union[str,Sequence[str]]): Name of the column(s) to apply the shift to. + sign (Union[int,Sequence[int]]): Sign of the shift to apply. (+1 or -1) + mode (str): The mode of the shift. One of 'direct', 'average' or rolled. + if rolled, window and sigma must be given. + config (dict): Configuration dictionary. + **kwargs: Additional arguments for the rolling average function. + """ + if energy_column is None: + if config is None: + raise ValueError("Either energy_column or config must be given.") + energy_column = config["dataframe"]["energy_column"] + if isinstance(columns,str): + columns=[columns] + if isinstance(signs,int): + signs=[signs] + if isinstance(mode,str): + mode = [mode] * len(columns) + if len(columns) != len(signs): + raise ValueError("column_name and sign must have the same length.") + with ProgressBar(minimum=5,): + if mode == "rolled": + if window is None: + if config is None: + raise ValueError("Either window or config must be given.") + window = config["dataframe"]["rolling_window"] + if sigma is None: + if config is None: + raise ValueError("Either sigma or config must be given.") + sigma = config["dataframe"]["rolling_sigma"] + if rolling_group_channel is None: + if config is None: + raise ValueError("Either rolling_group_channel or config must be given.") + rolling_group_channel = config["dataframe"]["rolling_group_channel"] + print('rolling averages...') + df = rolling_average_on_acquisition_time( + df, + rolling_group_channel=rolling_group_channel, + columns=columns, + window=window, + sigma=sigma, + ) + for col, s, m in zip(columns,signs, mode): + s = s/np.abs(s) # enusre s is either +1 or -1 + if m == "rolled": + col = col + '_rolled' + if m == "direct" or m == "rolled": + df[col] = df.map_partitions( + lambda x: x[col] + s * x[energy_column], meta=(col, np.float32) + ) + elif m == 'mean': + print('computing means...') + col_mean = df[col].mean() + df[col] = df.map_partitions( + lambda x: x[col] + s * (x[energy_column] - col_mean), meta=(col, np.float32) + ) + else: + raise ValueError(f"mode must be one of 'direct', 'mean' or 'rolled'. Got {m}.") + metadata: dict[str,Any] = { + "applied": True, + "energy_column": energy_column, + "column_name": columns, + "sign": signs, + } + return df, metadata + \ No newline at end of file