Skip to content

Commit

Permalink
add rolling average and energy shift
Browse files Browse the repository at this point in the history
  • Loading branch information
steinnymir committed Oct 11, 2023
1 parent ccff3f8 commit 27adcc3
Showing 1 changed file with 131 additions and 0 deletions.
131 changes: 131 additions & 0 deletions sed/core/dfops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""
# Note: some of the functions presented here were
# inspired by https://github.com/mpes-kit/mpes
from typing import Any
from typing import Dict
from typing import Callable
from typing import Sequence
from typing import Union
Expand Down Expand Up @@ -193,3 +195,132 @@ def forward_fill_partition(df):
after=0,
)
return df


def rolling_average_on_acquisition_time(
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
rolling_group_channel: str,
columns: str = None,
window: float = None,
sigma: float = 2,
config: dict = None,

) -> Union[pd.DataFrame, dask.dataframe.DataFrame]:
""" Perform a rolling average with a gaussian weighted window.
In order to preserve the number of points, the first and last "widnow"
number of points are substituted with the original signal.
# TODO: this is currently very slow, and could do with a remake.
Args:
df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to use.
group_channel: (str): Name of the column on which to group the data
cols (str): Name of the column on which to perform the rolling average
window (float): Size of the rolling average window
sigma (float): number of standard deviations for the gaussian weighting of the window.
a value of 2 corresponds to a gaussian with sigma equal to half the window size.
Smaller values reduce the weighting in the window frame.
Returns:
Union[pd.DataFrame, dask.dataframe.DataFrame]: Dataframe with the new columns.
"""
if group_channel is None:
if config is None:
raise ValueError("Either group_channel or config must be given.")
group_channel = config["dataframe"]["rolling_group_channel"]
with ProgressBar():
print(f'rolling average over {columns}...')
if isinstance(columns,str):
columns=[columns]
df_ = df.groupby(rolling_group_channel).agg({c:'mean' for c in columns}).compute()
df_['dt'] = pd.to_datetime(df_.index, unit='s')
df_['ts'] = df_.index
for c in columns:
df_[c+'_rolled'] = df_[c].interpolate(method='nearest').rolling(window,center=True,win_type='gaussian').mean(std=window/sigma).fillna(df_[c])
df_ = df_.drop(c, axis=1)
if c+'_rolled' in df.columns:
df = df.drop(c+'_rolled',axis=1)
return df.merge(df_,left_on='timeStamp',right_on='ts').drop(['ts','dt'], axis=1)


def apply_energy_shift(
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
columns: Union[str,Sequence[str]],
signs: Union[int,Sequence[int]],
energy_column: str = None,
mode: Union[str,Sequence[str]] = "direct",
window: float = None,
sigma: float = 2,
rolling_group_channel: str = None,
config: dict = None,
) -> Union[pd.DataFrame, dask.dataframe.DataFrame]:
""" Apply an energy shift to the given column(s).
Args:
df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to use.
energy_column (str): Name of the column containing the energy values.
column_name (Union[str,Sequence[str]]): Name of the column(s) to apply the shift to.
sign (Union[int,Sequence[int]]): Sign of the shift to apply. (+1 or -1)
mode (str): The mode of the shift. One of 'direct', 'average' or rolled.
if rolled, window and sigma must be given.
config (dict): Configuration dictionary.
**kwargs: Additional arguments for the rolling average function.
"""
if energy_column is None:
if config is None:
raise ValueError("Either energy_column or config must be given.")
energy_column = config["dataframe"]["energy_column"]
if isinstance(columns,str):
columns=[columns]
if isinstance(signs,int):
signs=[signs]
if isinstance(mode,str):
mode = [mode] * len(columns)
if len(columns) != len(signs):
raise ValueError("column_name and sign must have the same length.")
with ProgressBar(minimum=5,):
if mode == "rolled":
if window is None:
if config is None:
raise ValueError("Either window or config must be given.")
window = config["dataframe"]["rolling_window"]
if sigma is None:
if config is None:
raise ValueError("Either sigma or config must be given.")
sigma = config["dataframe"]["rolling_sigma"]
if rolling_group_channel is None:
if config is None:
raise ValueError("Either rolling_group_channel or config must be given.")
rolling_group_channel = config["dataframe"]["rolling_group_channel"]
print('rolling averages...')
df = rolling_average_on_acquisition_time(
df,
rolling_group_channel=rolling_group_channel,
columns=columns,
window=window,
sigma=sigma,
)
for col, s, m in zip(columns,signs, mode):
s = s/np.abs(s) # enusre s is either +1 or -1
if m == "rolled":
col = col + '_rolled'
if m == "direct" or m == "rolled":
df[col] = df.map_partitions(
lambda x: x[col] + s * x[energy_column], meta=(col, np.float32)
)
elif m == 'mean':
print('computing means...')
col_mean = df[col].mean()
df[col] = df.map_partitions(
lambda x: x[col] + s * (x[energy_column] - col_mean), meta=(col, np.float32)
)
else:
raise ValueError(f"mode must be one of 'direct', 'mean' or 'rolled'. Got {m}.")
metadata: dict[str,Any] = {
"applied": True,
"energy_column": energy_column,
"column_name": columns,
"sign": signs,
}
return df, metadata

0 comments on commit 27adcc3

Please sign in to comment.