Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add rolling average and energy shift #174

Merged
merged 4 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 141 additions & 1 deletion sed/calibrator/hextof.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can your break this file up, and put the repsective functions into the calibrator files related to the physical axes? "hextof" is not a physical axis.
So, put energy functions into energy, momentum into momentum, etc. and for the tof-related stuff (dld calibration) make another file tof.py.
I think there should not be anything specific in the main code to "mpes" or "hextof" or so, rather than the loaders, and that's how I tried to design the rest of the code. Off course one or the other function will only work in specific contexts, but those should not be limited to specific instruments.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont intend to merge the hextof folder. I just find the thousands of lines of those files scary! 😓

I'll trasnfer everything to its proper place, promised!

Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import pandas as pd
import dask.dataframe
from dask.diagnostics import ProgressBar


def unravel_8s_detector_time_channel(
Expand Down Expand Up @@ -151,6 +152,145 @@ def dld_time_to_ns(
}
return df, metadata


def rolling_average_on_acquisition_time(
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
rolling_group_channel: str,
columns: str = None,
window: float = None,
sigma: float = 2,
config: dict = None,

) -> Union[pd.DataFrame, dask.dataframe.DataFrame]:
""" Perform a rolling average with a gaussian weighted window.

In order to preserve the number of points, the first and last "widnow"
number of points are substituted with the original signal.
# TODO: this is currently very slow, and could do with a remake.

Args:
df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to use.
group_channel: (str): Name of the column on which to group the data
cols (str): Name of the column on which to perform the rolling average
window (float): Size of the rolling average window
sigma (float): number of standard deviations for the gaussian weighting of the window.
a value of 2 corresponds to a gaussian with sigma equal to half the window size.
Smaller values reduce the weighting in the window frame.

Returns:
Union[pd.DataFrame, dask.dataframe.DataFrame]: Dataframe with the new columns.
"""
if rolling_group_channel is None:
if config is None:
raise ValueError("Either group_channel or config must be given.")
rolling_group_channel = config["dataframe"]["rolling_group_channel"]
with ProgressBar():
print(f'rolling average over {columns}...')
if isinstance(columns,str):
columns=[columns]
df_ = df.groupby(rolling_group_channel).agg({c:'mean' for c in columns}).compute()
df_['dt'] = pd.to_datetime(df_.index, unit='s')
df_['ts'] = df_.index
for c in columns:
df_[c+'_rolled'] = df_[c].interpolate(
method='nearest'
).rolling(
window,center=True,win_type='gaussian'
).mean(
std=window/sigma
).fillna(df_[c])
df_ = df_.drop(c, axis=1)
if c+'_rolled' in df.columns:
df = df.drop(c+'_rolled',axis=1)
return df.merge(df_,left_on='timeStamp',right_on='ts').drop(['ts','dt'], axis=1)


def shift_energy_axis(
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
columns: Union[str,Sequence[str]],
signs: Union[int,Sequence[int]],
energy_column: str = None,
mode: Union[str,Sequence[str]] = "direct",
window: float = None,
sigma: float = 2,
rolling_group_channel: str = None,
config: dict = None,
) -> Union[pd.DataFrame, dask.dataframe.DataFrame]:
""" Apply an energy shift to the given column(s).

Args:
df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to use.
energy_column (str): Name of the column containing the energy values.
column_name (Union[str,Sequence[str]]): Name of the column(s) to apply the shift to.
sign (Union[int,Sequence[int]]): Sign of the shift to apply. (+1 or -1)
mode (str): The mode of the shift. One of 'direct', 'average' or rolled.
if rolled, window and sigma must be given.
config (dict): Configuration dictionary.
**kwargs: Additional arguments for the rolling average function.
"""
if energy_column is None:
if config is None:
raise ValueError("Either energy_column or config must be given.")
energy_column = config["dataframe"]["energy_column"]

if isinstance(columns,str):
columns=[columns]
if isinstance(signs,int):
signs=[signs]
if isinstance(mode,str):
mode = [mode] * len(columns)
if len(columns) != len(signs):
raise ValueError("column_name and sign must have the same length.")
with ProgressBar(minimum=5,):
if mode == "rolled":
if window is None:
if config is None:
raise ValueError("Either window or config must be given.")
window = config["dataframe"]["rolling_window"]
if sigma is None:
if config is None:
raise ValueError("Either sigma or config must be given.")
sigma = config["dataframe"]["rolling_sigma"]
if rolling_group_channel is None:
if config is None:
raise ValueError("Either rolling_group_channel or config must be given.")
rolling_group_channel = config["dataframe"].get("rolling_group_channel",None)
if rolling_group_channel is None:
raise ValueError("T f mode is 'rolled', rolling_group_channel must be"
"given or present in config.")
print('rolling averages...')
df = rolling_average_on_acquisition_time(
df,
rolling_group_channel=rolling_group_channel,
columns=columns,
window=window,
sigma=sigma,
)
for col, s, m in zip(columns,signs, mode):
s = s/np.abs(s) # enusre s is either +1 or -1
if m == "rolled":
col = col + '_rolled'
if m == "direct" or m == "rolled":
df[col] = df.map_partitions(
lambda x: x[col] + s * x[energy_column], meta=(col, np.float32)
)
elif m == 'mean':
print('computing means...')
col_mean = df[col].mean()
df[col] = df.map_partitions(
lambda x: x[col] + s * (x[energy_column] - col_mean), meta=(col, np.float32)
)
else:
raise ValueError(f"mode must be one of 'direct', 'mean' or 'rolled'. Got {m}.")
metadata: dict[str,Any] = {
"applied": True,
"energy_column": energy_column,
"column_name": columns,
"sign": signs,
}
return df, metadata


def step2ns(
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
tof_column: str,
Expand All @@ -172,4 +312,4 @@ def step2ns(
Union[pd.DataFrame, dask.dataframe.DataFrame]: Dataframe with the new column.
"""
val = df[tof_column].astype(dtype) * tof_binwidth * 2**tof_binning
return val.astype(dtype)
return val.astype(dtype)
137 changes: 137 additions & 0 deletions sed/core/dfops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""
# Note: some of the functions presented here were
# inspired by https://github.com/mpes-kit/mpes
from typing import Any
from typing import Dict
from typing import Callable
from typing import Sequence
from typing import Union
Expand Down Expand Up @@ -193,3 +195,138 @@ def forward_fill_partition(df):
after=0,
)
return df


def rolling_average_on_acquisition_time(
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
rolling_group_channel: str,
columns: str = None,
window: float = None,
sigma: float = 2,
config: dict = None,

) -> Union[pd.DataFrame, dask.dataframe.DataFrame]:
""" Perform a rolling average with a gaussian weighted window.

In order to preserve the number of points, the first and last "widnow"
number of points are substituted with the original signal.
# TODO: this is currently very slow, and could do with a remake.

Args:
df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to use.
group_channel: (str): Name of the column on which to group the data
cols (str): Name of the column on which to perform the rolling average
window (float): Size of the rolling average window
sigma (float): number of standard deviations for the gaussian weighting of the window.
a value of 2 corresponds to a gaussian with sigma equal to half the window size.
Smaller values reduce the weighting in the window frame.

Returns:
Union[pd.DataFrame, dask.dataframe.DataFrame]: Dataframe with the new columns.
"""
if rolling_group_channel is None:
if config is None:
raise ValueError("Either group_channel or config must be given.")
rolling_group_channel = config["dataframe"]["rolling_group_channel"]
with ProgressBar():
print(f'rolling average over {columns}...')
if isinstance(columns,str):
columns=[columns]
df_ = df.groupby(rolling_group_channel).agg({c:'mean' for c in columns}).compute()
df_['dt'] = pd.to_datetime(df_.index, unit='s')
df_['ts'] = df_.index
for c in columns:
df_[c+'_rolled'] = df_[c].interpolate(
method='nearest'
).rolling(
window,center=True,win_type='gaussian'
).mean(
std=window/sigma
).fillna(df_[c])
df_ = df_.drop(c, axis=1)
if c+'_rolled' in df.columns:
df = df.drop(c+'_rolled',axis=1)
return df.merge(df_,left_on='timeStamp',right_on='ts').drop(['ts','dt'], axis=1)


def apply_energy_shift(
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
columns: Union[str,Sequence[str]],
signs: Union[int,Sequence[int]],
energy_column: str = None,
mode: Union[str,Sequence[str]] = "direct",
window: float = None,
sigma: float = 2,
rolling_group_channel: str = None,
config: dict = None,
) -> Union[pd.DataFrame, dask.dataframe.DataFrame]:
""" Apply an energy shift to the given column(s).

Args:
df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to use.
energy_column (str): Name of the column containing the energy values.
column_name (Union[str,Sequence[str]]): Name of the column(s) to apply the shift to.
sign (Union[int,Sequence[int]]): Sign of the shift to apply. (+1 or -1)
mode (str): The mode of the shift. One of 'direct', 'average' or rolled.
if rolled, window and sigma must be given.
config (dict): Configuration dictionary.
**kwargs: Additional arguments for the rolling average function.
"""
if energy_column is None:
if config is None:
raise ValueError("Either energy_column or config must be given.")
energy_column = config["dataframe"]["energy_column"]
if isinstance(columns,str):
columns=[columns]
if isinstance(signs,int):
signs=[signs]
if isinstance(mode,str):
mode = [mode] * len(columns)
if len(columns) != len(signs):
raise ValueError("column_name and sign must have the same length.")
with ProgressBar(minimum=5,):
if mode == "rolled":
if window is None:
if config is None:
raise ValueError("Either window or config must be given.")
window = config["dataframe"]["rolling_window"]
if sigma is None:
if config is None:
raise ValueError("Either sigma or config must be given.")
sigma = config["dataframe"]["rolling_sigma"]
if rolling_group_channel is None:
if config is None:
raise ValueError("Either rolling_group_channel or config must be given.")
rolling_group_channel = config["dataframe"]["rolling_group_channel"]
print('rolling averages...')
df = rolling_average_on_acquisition_time(
df,
rolling_group_channel=rolling_group_channel,
columns=columns,
window=window,
sigma=sigma,
)
for col, s, m in zip(columns,signs, mode):
s = s/np.abs(s) # enusre s is either +1 or -1
if m == "rolled":
col = col + '_rolled'
if m == "direct" or m == "rolled":
df[col] = df.map_partitions(
lambda x: x[col] + s * x[energy_column], meta=(col, np.float32)
)
elif m == 'mean':
print('computing means...')
col_mean = df[col].mean()
df[col] = df.map_partitions(
lambda x: x[col] + s * (x[energy_column] - col_mean), meta=(col, np.float32)
)
else:
raise ValueError(f"mode must be one of 'direct', 'mean' or 'rolled'. Got {m}.")
metadata: dict[str,Any] = {
"applied": True,
"energy_column": energy_column,
"column_name": columns,
"sign": signs,
}
return df, metadata

31 changes: 31 additions & 0 deletions sed/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,37 @@ def calibrate_delay_axis(
else:
print(self._dataframe)

def shift_energy_axis(
self,
columns: Union[str,Sequence[str]],
signs: Union[int,Sequence[int]],
mode: Union[str,Sequence[str]] = "direct",
window: float = None,
sigma: float = 2,
rolling_group_channel: str = None,
) -> None:
energy_column = self._config["dataframe"]["energy_column"]
if energy_column not in self._dataframe.columns:
raise ValueError(
f"Energy column {energy_column} not found in dataframe! "
"Run energy calibration first",
)
self._dataframe, metadata = hextof.shift_energy_axis(
df=self._dataframe,
columns=columns,
signs=signs,
mode=mode,
window=window,
sigma=sigma,
rolling_group_channel=rolling_group_channel,
config=self._config,
)
self._attributes.add(
metadata,
"shift_energy_axis",
duplicate_policy="raise",
)

def add_jitter(self, cols: Sequence[str] = None):
"""Add jitter to the selected dataframe columns.

Expand Down
Loading