Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Add a decorator method to record function calls with parameters #97

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion sed/calibrator/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import numpy as np
import pandas as pd

from sed.core.workflow_recorder import MethodCall
from sed.core.workflow_recorder import track_call


class DelayCalibrator:
"""
Expand All @@ -18,14 +21,19 @@ class DelayCalibrator:
def __init__(
self,
config: dict = None,
tracker: List[MethodCall] = None,
):
"""Initialization of the DelayCalibrator class passes the config dict."""

# pylint: disable=duplicate-code
if config is None:
config = {}

self._config = config

if tracker is None:
tracker = []
self._call_tracker = tracker

self.adc_column = self._config.get("dataframe", {}).get(
"adc_column",
"ADC",
Expand All @@ -35,6 +43,13 @@ def __init__(
"delay",
)

@property
def call_tracker(self) -> List[MethodCall]:
"""List of tracked function calls."""

return self._call_tracker

@track_call
def append_delay_axis(
self,
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
Expand Down
99 changes: 68 additions & 31 deletions sed/calibrator/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from silx.io import dictdump

from sed.binning import bin_dataframe
from sed.core.workflow_recorder import MethodCall
from sed.core.workflow_recorder import track_call
from sed.loader.base.loader import BaseLoader


Expand All @@ -53,6 +55,7 @@ def __init__(
traces: np.ndarray = None,
tof: np.ndarray = None,
config: dict = None,
tracker: List[MethodCall] = None,
):
"""The initialization of the EnergyCalibrator class, can happen by passing
the following parameters:
Expand All @@ -73,11 +76,15 @@ def __init__(
if traces is not None and tof is not None and biases is not None:
self.load_data(biases=biases, traces=traces, tof=tof)

# pylint: disable=duplicate-code
if config is None:
config = {}

self._config = config

if tracker is None:
tracker = []
self._call_tracker = tracker

self.featranges: List[Tuple] = [] # Value ranges for feature detection
self.peaks: np.ndarray = np.asarray([])
self.calibration: Dict[Any, Any] = {}
Expand Down Expand Up @@ -123,6 +130,12 @@ def __init__(
},
)

@property
def call_tracker(self) -> List[MethodCall]:
"""List of tracked function calls."""

return self._call_tracker

@property
def ntraces(self) -> int:
"""The number of loaded/calculated traces."""
Expand Down Expand Up @@ -167,6 +180,7 @@ def load_data(
else:
self.traces = self.traces_normed = np.asarray([])

@track_call
def bin_data(
self,
data_files: List[str],
Expand Down Expand Up @@ -241,6 +255,7 @@ def bin_data(
self.tof = np.asarray(tof)
self.biases = np.asarray(biases)

@track_call
def normalize(self, smooth: bool = False, span: int = 7, order: int = 1):
"""Normalize the spectra along an axis.

Expand All @@ -259,6 +274,7 @@ def normalize(self, smooth: bool = False, span: int = 7, order: int = 1):
order=order,
)

@track_call
def add_features(
self,
ranges: Union[List[Tuple], Tuple],
Expand Down Expand Up @@ -316,6 +332,7 @@ def add_features(
elif mode == "replace":
self.featranges = newranges

@track_call
def feature_extract(
self,
ranges: List[Tuple] = None,
Expand Down Expand Up @@ -349,6 +366,7 @@ def feature_extract(
pkwindow=peak_window,
)

@track_call
def calibrate(
self,
ref_id: int = 0,
Expand Down Expand Up @@ -609,6 +627,7 @@ def view( # pylint: disable=dangerous-default-value

pbk.show(fig)

@track_call
def append_energy_axis(
self,
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
Expand Down Expand Up @@ -885,14 +904,12 @@ def update(amplitude, x_center, y_center, **kwds):
)

def apply_func(apply: bool): # pylint: disable=unused-argument
self.correction["amplitude"] = amplitude_slider.value
self.correction["center"] = (
x_center_slider.value,
y_center_slider.value,
self.set_energy_correction(
amplitude=amplitude_slider.value,
center=(x_center_slider.value, y_center_slider.value),
correction_type=correction_type,
diameter=diameter_slider.value,
)
self.correction["correction_type"] = correction_type
kwds["diameter"] = diameter_slider.value
self.correction["kwds"] = kwds
amplitude_slider.close()
x_center_slider.close()
y_center_slider.close()
Expand Down Expand Up @@ -920,14 +937,12 @@ def apply_func(apply: bool): # pylint: disable=unused-argument
)

def apply_func(apply: bool): # pylint: disable=unused-argument
self.correction["amplitude"] = amplitude_slider.value
self.correction["center"] = (
x_center_slider.value,
y_center_slider.value,
self.set_energy_correction(
amplitude=amplitude_slider.value,
center=(x_center_slider.value, y_center_slider.value),
correction_type=correction_type,
gamma=gamma_slider.value,
)
self.correction["correction_type"] = correction_type
kwds["gamma"] = gamma_slider.value
self.correction["kwds"] = kwds
amplitude_slider.close()
x_center_slider.close()
y_center_slider.close()
Expand Down Expand Up @@ -955,14 +970,12 @@ def apply_func(apply: bool): # pylint: disable=unused-argument
)

def apply_func(apply: bool): # pylint: disable=unused-argument
self.correction["amplitude"] = amplitude_slider.value
self.correction["center"] = (
x_center_slider.value,
y_center_slider.value,
self.set_energy_correction(
amplitude=amplitude_slider.value,
center=(x_center_slider.value, y_center_slider.value),
correction_type=correction_type,
sigma=sigma_slider.value,
)
self.correction["correction_type"] = correction_type
kwds["sigma"] = sigma_slider.value
self.correction["kwds"] = kwds
amplitude_slider.close()
x_center_slider.close()
y_center_slider.close()
Expand Down Expand Up @@ -1015,16 +1028,14 @@ def apply_func(apply: bool): # pylint: disable=unused-argument
)

def apply_func(apply: bool): # pylint: disable=unused-argument
self.correction["amplitude"] = amplitude_slider.value
self.correction["center"] = (
x_center_slider.value,
y_center_slider.value,
self.set_energy_correction(
amplitude=amplitude_slider.value,
center=(x_center_slider.value, y_center_slider.value),
correction_type=correction_type,
gamma=gamma_slider.value,
amplitude2=amplitude2_slider.value,
gamma2=gamma2_slider.value,
)
self.correction["correction_type"] = correction_type
kwds["gamma"] = gamma_slider.value
kwds["amplitude2"] = amplitude2_slider.value
kwds["gamma2"] = gamma2_slider.value
self.correction["kwds"] = kwds
amplitude_slider.close()
x_center_slider.close()
y_center_slider.close()
Expand All @@ -1044,6 +1055,32 @@ def apply_func(apply: bool): # pylint: disable=unused-argument
if apply:
apply_func(True)

@track_call
def set_energy_correction(
self,
amplitude: float,
correction_type: str,
center: Tuple[float, float],
**kwds,
):
"""Set the supplied energy correction parameters in the class attribute

Args:
amplitude (float):
Amplitude of the time-of-flight correction term
correction_type (str):
Type of correction to apply to the TOF axis. Defaults to config value.
center (Tuple(float, float)):
Center pixel (x, y) of the correction term.
**kwds:
Additional keyword arguments specific to the various correction_types
"""
self.correction["amplitude"] = amplitude
self.correction["center"] = center
self.correction["correction_type"] = correction_type
self.correction["kwds"] = kwds

@track_call
def apply_energy_correction(
self,
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
Expand Down
28 changes: 27 additions & 1 deletion sed/calibrator/momentum.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from symmetrize import sym
from symmetrize import tps

from sed.core.workflow_recorder import MethodCall
from sed.core.workflow_recorder import track_call


class MomentumCorrector:
"""
Expand All @@ -42,6 +45,7 @@ def __init__(
bin_ranges: List[Tuple] = None,
rotsym: int = 6,
config: dict = None,
tracker: List[MethodCall] = None,
):
"""
Parameters:
Expand All @@ -60,11 +64,15 @@ def __init__(
if data is not None:
self.load_data(data=data, bin_ranges=bin_ranges, rotsym=rotsym)

# pylint: disable=duplicate-code
if config is None:
config = {}

self._config = config

if tracker is None:
tracker = []
self._call_tracker = tracker

self.detector_ranges = self._config.get("momentum", {}).get(
"detector_ranges",
[[0, 2048], [0, 2048]],
Expand Down Expand Up @@ -93,6 +101,12 @@ def __init__(
self.transformations: Dict[Any, Any] = {}
self.calibration: Dict[Any, Any] = {}

@property
def call_tracker(self) -> List[MethodCall]:
"""List of tracked function calls."""

return self._call_tracker

@property
def features(self) -> dict:
"""Dictionary of detected features for the symmetrization process.
Expand Down Expand Up @@ -264,6 +278,7 @@ def apply_fun(apply: bool): # pylint: disable=unused-argument
if apply:
apply_fun(True)

@track_call
def select_slice(
self,
selector: Union[slice, List[int], int],
Expand Down Expand Up @@ -293,6 +308,7 @@ def select_slice(
elif self.img_ndim == 2:
raise ValueError("Input image dimension is already 2!")

@track_call
def add_features(
self,
peaks: np.ndarray,
Expand Down Expand Up @@ -355,6 +371,7 @@ def add_features(
self.mcvdist = self.mdist
self.mvvdist = self.mdist

@track_call
def feature_extract(
self,
image: np.ndarray = None,
Expand Down Expand Up @@ -424,6 +441,7 @@ def calc_symmetry_scores(self, symtype: str = "rotation") -> float:

return csm

@track_call
def spline_warp_estimate(
self,
image: np.ndarray = None,
Expand Down Expand Up @@ -515,6 +533,7 @@ def spline_warp_estimate(

return self.slice_corrected

@track_call
def apply_correction(
self,
image: np.ndarray,
Expand Down Expand Up @@ -591,6 +610,7 @@ def update_deformation(self, rdeform: np.ndarray, cdeform: np.ndarray):
cval=np.nan,
)

@track_call
def coordinate_transform(
self,
transform_type: str,
Expand Down Expand Up @@ -741,6 +761,9 @@ def coordinate_transform(
rdeform,
cdeform,
)
else:
# remove tracked call to function if not keeping transformation
self.call_tracker.pop()

return slice_transformed

Expand Down Expand Up @@ -1068,6 +1091,7 @@ def view( # pylint: disable=dangerous-default-value

pbk.show(fig)

@track_call
def calibrate(
self,
point_a: Union[np.ndarray, List[int]],
Expand Down Expand Up @@ -1174,6 +1198,7 @@ def calibrate(

return self.calibration

@track_call
def apply_distortion_correction(
self,
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
Expand Down Expand Up @@ -1233,6 +1258,7 @@ def apply_distortion_correction(
)
return out_df

@track_call
def append_k_axis(
self,
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
Expand Down
Loading