Skip to content

Commit

Permalink
move more functions
Browse files Browse the repository at this point in the history
  • Loading branch information
steinnymir committed Oct 13, 2023
1 parent 98c1857 commit bfd3daa
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 92 deletions.
3 changes: 2 additions & 1 deletion sed/calibrator/dld.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import pandas as pd


# TODO: this could be generalized and moved to dfops for splitting a channel bitwise
def unravel_8s_detector_time_channel(
df: dask.dataframe.DataFrame,
tof_column: str = None,
Expand All @@ -21,6 +20,8 @@ def unravel_8s_detector_time_channel(
) -> dask.dataframe.DataFrame:
"""Converts the 8s time in steps to time in steps and sectorID.
# TODO: this could be generalized and moved to dfops for splitting a channel bitwise
The 8s detector encodes the dldSectorID in the 3 least significant bits of the
dldTimeSteps channel.
Expand Down
92 changes: 90 additions & 2 deletions sed/calibrator/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import xarray as xr
from bokeh.io import output_notebook
from bokeh.palettes import Category10 as ColorCycle
from dask.diagnostics import ProgressBar
from fastdtw import fastdtw
from IPython.display import display
from lmfit import Minimizer
Expand All @@ -35,6 +36,7 @@
from scipy.sparse.linalg import lsqr

from sed.binning import bin_dataframe
from sed.core import dfops
from sed.loader.base.loader import BaseLoader


Expand Down Expand Up @@ -2127,12 +2129,12 @@ def tof_step_to_ns(
raise ValueError("Either tof_ns_column or config must be given.")
tof_ns_column = config["dataframe"]["tof_ns_column"]

df[tof_ns_column] = df.map_partitions(step2ns, meta=(tof_column, np.float64))
df[tof_ns_column] = df.map_partitions(tof2ns, meta=(tof_column, np.float64))
metadata: Dict[str, Any] = {"applied": True, "tof_binwidth": tof_binwidth}
return df, metadata


def step2ns(
def tof2ns(
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
tof_column: str,
tof_binwidth: float,
Expand All @@ -2154,3 +2156,89 @@ def step2ns(
"""
val = df[tof_column].astype(dtype) * tof_binwidth * 2**tof_binning
return val.astype(dtype)


def _shift_partition(part, column, energy_column, sign, offset=None):
    """Return ``part[column] + sign * (part[energy_column] - offset)`` for one partition.

    ``offset`` is skipped entirely when it is ``None``.
    """
    energy = part[energy_column]
    if offset is not None:
        energy = energy - offset
    return part[column] + sign * energy


def apply_energy_shift(
    df: Union[pd.DataFrame, dask.dataframe.DataFrame],
    columns: Union[str, Sequence[str]],
    signs: Union[int, Sequence[int]],
    energy_column: str = None,
    mode: Union[str, Sequence[str]] = "direct",
    window: float = None,
    sigma: float = 2,
    rolling_group_channel: str = None,
    config: dict = None,
) -> Union[pd.DataFrame, dask.dataframe.DataFrame]:
    """Apply an energy shift to the given column(s).

    For every entry of ``columns`` the column is overwritten, depending on its mode:

    * ``'direct'``: ``col + sign * energy``
    * ``'rolled'``: same arithmetic, applied to the ``<col>_rolled`` column created by
      ``dfops.rolling_average_on_acquisition_time``
    * ``'mean'``: ``col + sign * (energy - mean(col))``

    NOTE(review): the result is written to the shift column, not to ``energy_column``;
    confirm that this is the intended target.

    Args:
        df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to use. The
            implementation calls ``df.map_partitions``.
        columns (Union[str, Sequence[str]]): Name of the column(s) to apply the shift to.
        signs (Union[int, Sequence[int]]): Sign of the shift for each column. Only the
            sign of the value is used (+1 or -1).
        energy_column (str): Name of the column containing the energy values. Read from
            ``config["dataframe"]["energy_column"]`` if not given.
        mode (Union[str, Sequence[str]]): One of 'direct', 'mean' or 'rolled'. A single
            string is applied to all columns.
        window (float): Window passed to the rolling average ('rolled' mode). Read from
            ``config["dataframe"]["rolling_window"]`` if ``None``.
        sigma (float): Sigma passed to the rolling average ('rolled' mode). Read from
            ``config["dataframe"]["rolling_sigma"]`` if ``None``.
        rolling_group_channel (str): Group channel passed to the rolling average. Read
            from ``config["dataframe"]["rolling_group_channel"]`` if ``None``.
        config (dict): Configuration dictionary, used for every value not given.

    Raises:
        ValueError: If a required value is given neither directly nor via ``config``, if
            ``columns``, ``signs`` and ``mode`` differ in length, if a sign is zero, or
            if a mode is unknown.

    Returns:
        A tuple ``(df, metadata)`` of the shifted dataframe and a metadata dictionary.
        NOTE(review): the return annotation only names the dataframe.
    """
    if energy_column is None:
        if config is None:
            raise ValueError("Either energy_column or config must be given.")
        energy_column = config["dataframe"]["energy_column"]
    if isinstance(columns, str):
        columns = [columns]
    if isinstance(signs, (int, float, np.number)):
        signs = [signs]
    if isinstance(mode, str):
        mode = [mode] * len(columns)
    if len(columns) != len(signs):
        raise ValueError("column_name and sign must have the same length.")
    # zip() below would silently drop columns if mode were shorter.
    if len(columns) != len(mode):
        raise ValueError("column_name and mode must have the same length.")
    # Validate up front so nothing expensive runs before a bad mode is reported.
    for m in mode:
        if m not in ("direct", "mean", "rolled"):
            raise ValueError(f"mode must be one of 'direct', 'mean' or 'rolled'. Got {m}.")

    # `mode` is a list at this point, so it has to be searched, not compared to a string.
    rolled_columns = [c for c, m in zip(columns, mode) if m == "rolled"]
    if rolled_columns:
        if window is None:
            if config is None:
                raise ValueError("Either window or config must be given.")
            window = config["dataframe"]["rolling_window"]
        if sigma is None:
            if config is None:
                raise ValueError("Either sigma or config must be given.")
            sigma = config["dataframe"]["rolling_sigma"]
        if rolling_group_channel is None:
            if config is None:
                raise ValueError("Either rolling_group_channel or config must be given.")
            rolling_group_channel = config["dataframe"]["rolling_group_channel"]
        print("rolling averages...")
        with ProgressBar(
            minimum=5,
        ):
            df = dfops.rolling_average_on_acquisition_time(
                df,
                rolling_group_channel=rolling_group_channel,
                columns=rolled_columns,
                window=window,
                sigma=sigma,
            )

    for col, s, m in zip(columns, signs, mode):
        if s == 0:
            raise ValueError("sign must be non-zero (+1 or -1).")
        unit_sign = 1.0 if s > 0 else -1.0  # ensure the sign is either +1 or -1
        target = col + "_rolled" if m == "rolled" else col
        offset = None
        if m == "mean":
            print("computing means...")
            offset = df[target].mean()
        # Per-column values are passed as keyword arguments instead of being captured by
        # a lambda: map_partitions is lazy, so a closure over the loop variables would
        # see the values of the last iteration for every column.
        df[target] = df.map_partitions(
            _shift_partition,
            column=target,
            energy_column=energy_column,
            sign=unit_sign,
            offset=offset,
            meta=(target, np.float32),
        )
    metadata: dict[str, Any] = {
        "applied": True,
        "energy_column": energy_column,
        "column_name": columns,
        "sign": signs,
    }
    return df, metadata
87 changes: 0 additions & 87 deletions sed/core/dfops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""
# Note: some of the functions presented here were
# inspired by https://github.com/mpes-kit/mpes
from typing import Any
from typing import Callable
from typing import Sequence
from typing import Union
Expand Down Expand Up @@ -245,89 +244,3 @@ def rolling_average_on_acquisition_time(
if c + "_rolled" in df.columns:
df = df.drop(c + "_rolled", axis=1)
return df.merge(df_, left_on="timeStamp", right_on="ts").drop(["ts", "dt"], axis=1)


def _shift_partition(part, column, energy_column, sign, offset=None):
    """Return ``part[column] + sign * (part[energy_column] - offset)`` for one partition.

    ``offset`` is skipped entirely when it is ``None``.
    """
    energy = part[energy_column]
    if offset is not None:
        energy = energy - offset
    return part[column] + sign * energy


def apply_energy_shift(
    df: Union[pd.DataFrame, dask.dataframe.DataFrame],
    columns: Union[str, Sequence[str]],
    signs: Union[int, Sequence[int]],
    energy_column: str = None,
    mode: Union[str, Sequence[str]] = "direct",
    window: float = None,
    sigma: float = 2,
    rolling_group_channel: str = None,
    config: dict = None,
) -> Union[pd.DataFrame, dask.dataframe.DataFrame]:
    """Apply an energy shift to the given column(s).

    For every entry of ``columns`` the column is overwritten, depending on its mode:

    * ``'direct'``: ``col + sign * energy``
    * ``'rolled'``: same arithmetic, applied to the ``<col>_rolled`` column created by
      ``rolling_average_on_acquisition_time``
    * ``'mean'``: ``col + sign * (energy - mean(col))``

    NOTE(review): the result is written to the shift column, not to ``energy_column``;
    confirm that this is the intended target.

    Args:
        df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to use. The
            implementation calls ``df.map_partitions``.
        columns (Union[str, Sequence[str]]): Name of the column(s) to apply the shift to.
        signs (Union[int, Sequence[int]]): Sign of the shift for each column. Only the
            sign of the value is used (+1 or -1).
        energy_column (str): Name of the column containing the energy values. Read from
            ``config["dataframe"]["energy_column"]`` if not given.
        mode (Union[str, Sequence[str]]): One of 'direct', 'mean' or 'rolled'. A single
            string is applied to all columns.
        window (float): Window passed to the rolling average ('rolled' mode). Read from
            ``config["dataframe"]["rolling_window"]`` if ``None``.
        sigma (float): Sigma passed to the rolling average ('rolled' mode). Read from
            ``config["dataframe"]["rolling_sigma"]`` if ``None``.
        rolling_group_channel (str): Group channel passed to the rolling average. Read
            from ``config["dataframe"]["rolling_group_channel"]`` if ``None``.
        config (dict): Configuration dictionary, used for every value not given.

    Raises:
        ValueError: If a required value is given neither directly nor via ``config``, if
            ``columns``, ``signs`` and ``mode`` differ in length, if a sign is zero, or
            if a mode is unknown.

    Returns:
        A tuple ``(df, metadata)`` of the shifted dataframe and a metadata dictionary.
        NOTE(review): the return annotation only names the dataframe.
    """
    if energy_column is None:
        if config is None:
            raise ValueError("Either energy_column or config must be given.")
        energy_column = config["dataframe"]["energy_column"]
    if isinstance(columns, str):
        columns = [columns]
    if isinstance(signs, (int, float, np.number)):
        signs = [signs]
    if isinstance(mode, str):
        mode = [mode] * len(columns)
    if len(columns) != len(signs):
        raise ValueError("column_name and sign must have the same length.")
    # zip() below would silently drop columns if mode were shorter.
    if len(columns) != len(mode):
        raise ValueError("column_name and mode must have the same length.")
    # Validate up front so nothing expensive runs before a bad mode is reported.
    for m in mode:
        if m not in ("direct", "mean", "rolled"):
            raise ValueError(f"mode must be one of 'direct', 'mean' or 'rolled'. Got {m}.")

    # `mode` is a list at this point, so it has to be searched, not compared to a string.
    rolled_columns = [c for c, m in zip(columns, mode) if m == "rolled"]
    if rolled_columns:
        if window is None:
            if config is None:
                raise ValueError("Either window or config must be given.")
            window = config["dataframe"]["rolling_window"]
        if sigma is None:
            if config is None:
                raise ValueError("Either sigma or config must be given.")
            sigma = config["dataframe"]["rolling_sigma"]
        if rolling_group_channel is None:
            if config is None:
                raise ValueError("Either rolling_group_channel or config must be given.")
            rolling_group_channel = config["dataframe"]["rolling_group_channel"]
        print("rolling averages...")
        with ProgressBar(
            minimum=5,
        ):
            df = rolling_average_on_acquisition_time(
                df,
                rolling_group_channel=rolling_group_channel,
                columns=rolled_columns,
                window=window,
                sigma=sigma,
            )

    for col, s, m in zip(columns, signs, mode):
        if s == 0:
            raise ValueError("sign must be non-zero (+1 or -1).")
        unit_sign = 1.0 if s > 0 else -1.0  # ensure the sign is either +1 or -1
        target = col + "_rolled" if m == "rolled" else col
        offset = None
        if m == "mean":
            print("computing means...")
            offset = df[target].mean()
        # Per-column values are passed as keyword arguments instead of being captured by
        # a lambda: map_partitions is lazy, so a closure over the loop variables would
        # see the values of the last iteration for every column.
        df[target] = df.map_partitions(
            _shift_partition,
            column=target,
            energy_column=energy_column,
            sign=unit_sign,
            offset=offset,
            meta=(target, np.float32),
        )
    metadata: dict[str, Any] = {
        "applied": True,
        "energy_column": energy_column,
        "column_name": columns,
        "sign": signs,
    }
    return df, metadata
4 changes: 2 additions & 2 deletions sed/loader/flash/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from pandas import MultiIndex
from pandas import Series

from sed.calibrator.dld import unravel_8s_detector_time_channel
from sed.calibrator import dld
from sed.core import dfops
from sed.loader.base.loader import BaseLoader
from sed.loader.flash.metadata import MetadataRetriever
Expand Down Expand Up @@ -596,7 +596,7 @@ def create_dataframe_per_file(
df = df.dropna(subset=self._config["dataframe"].get("tof_column", "dldTimeSteps"))
# correct the 3 bit shift which encodes the detector ID in the 8s time
if self._config["dataframe"].get("unravel_8s_detector_time_channel", False):
df = unravel_8s_detector_time_channel(df, config=self._config)
df = dld.unravel_8s_detector_time_channel(df, config=self._config)
return df

def create_buffer_file(self, h5_path: Path, parquet_path: Path) -> Union[bool, Exception]:
Expand Down

0 comments on commit bfd3daa

Please sign in to comment.