Commit

use configmodel in processor class
zain-sohail committed Sep 16, 2024
1 parent f3de17a · commit 9e9c47b
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions sed/core/processor.py
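The change is mechanical throughout: flat `*_column` / `*_alias` keys in the `dataframe` config section are replaced by lookups into a nested `columns` mapping, matching the new ConfigModel schema. As a rough sketch of the shift in access pattern (only the keys visible in the diff below come from the commit; the concrete column names and any other keys are assumptions):

# Illustrative sketch only: column name values are invented, and keys other
# than those appearing in the diff below are assumptions.
old_config = {
    "dataframe": {
        "x_column": "X",
        "y_column": "Y",
        "tof_column": "dldTime",
        "time_stamp_alias": "timeStamps",
    },
}
new_config = {
    "dataframe": {
        "columns": {
            "x": "X",
            "y": "Y",
            "tof": "dldTime",
            "timestamp": "timeStamps",
        },
    },
}

# Old access pattern (removed)       ->  new access pattern (added)
# config["dataframe"]["tof_column"]  ->  config["dataframe"]["columns"]["tof"]
assert old_config["dataframe"]["tof_column"] == new_config["dataframe"]["columns"]["tof"]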
@@ -804,8 +804,8 @@ def apply_momentum_correction(
                 - **inv_dfield** (np.ndarray, optional): Inverse deformation field.
         """
-        x_column = self._config["dataframe"]["x_column"]
-        y_column = self._config["dataframe"]["y_column"]
+        x_column = self._config["dataframe"]["columns"]["x"]
+        y_column = self._config["dataframe"]["columns"]["y"]

         if self._dataframe is not None:
             logger.info("Adding corrected X/Y columns to dataframe:")
@@ -948,8 +948,8 @@ def apply_momentum_calibration(
                 Defaults to False.
             **kwds: Keyword args passed to ``MomentumCalibrator.append_k_axis``.
         """
-        x_column = self._config["dataframe"]["x_column"]
-        y_column = self._config["dataframe"]["y_column"]
+        x_column = self._config["dataframe"]["columns"]["x"]
+        y_column = self._config["dataframe"]["columns"]["y"]

         if self._dataframe is not None:
             logger.info("Adding kx/ky columns to dataframe:")
@@ -1089,7 +1089,7 @@ def apply_energy_correction(
             **kwds:
                 Keyword args passed to ``EnergyCalibrator.apply_energy_correction()``.
         """
-        tof_column = self._config["dataframe"]["tof_column"]
+        tof_column = self._config["dataframe"]["columns"]["tof"]

         if self._dataframe is not None:
             logger.info("Applying energy correction to dataframe...")
@@ -1168,16 +1168,16 @@ def load_bias_series(
         if binned_data is not None:
             if isinstance(binned_data, xr.DataArray):
                 if (
-                    self._config["dataframe"]["tof_column"] not in binned_data.dims
-                    or self._config["dataframe"]["bias_column"] not in binned_data.dims
+                    self._config["dataframe"]["columns"]["tof"] not in binned_data.dims
+                    or self._config["dataframe"]["columns"]["bias"] not in binned_data.dims
                 ):
                     raise ValueError(
                         "If binned_data is provided as an xarray, it needs to contain dimensions "
                         f"'{self._config['dataframe']['tof_column']}' and "
                         f"'{self._config['dataframe']['bias_column']}'!.",
                     )
-                tof = binned_data.coords[self._config["dataframe"]["tof_column"]].values
-                biases = binned_data.coords[self._config["dataframe"]["bias_column"]].values
+                tof = binned_data.coords[self._config["dataframe"]["columns"]["tof"]].values
+                biases = binned_data.coords[self._config["dataframe"]["columns"]["bias"]].values
                 traces = binned_data.values[:, :]
             else:
                 try:
@@ -1451,7 +1451,7 @@ def append_energy_axis(
             **kwds:
                 Keyword args passed to ``EnergyCalibrator.append_energy_axis()``.
         """
-        tof_column = self._config["dataframe"]["tof_column"]
+        tof_column = self._config["dataframe"]["columns"]["tof"]

         if self._dataframe is not None:
             logger.info("Adding energy column to dataframe:")
@@ -1517,7 +1517,7 @@ def add_energy_offset(
         Raises:
             ValueError: If the energy column is not in the dataframe.
         """
-        energy_column = self._config["dataframe"]["energy_column"]
+        energy_column = self._config["dataframe"]["columns"]["energy"]
         if energy_column not in self._dataframe.columns:
             raise ValueError(
                 f"Energy column {energy_column} not found in dataframe! "
@@ -1605,7 +1605,7 @@ def append_tof_ns_axis(
             **kwds: additional arguments are passed to ``EnergyCalibrator.append_tof_ns_axis()``.
         """
-        tof_column = self._config["dataframe"]["tof_column"]
+        tof_column = self._config["dataframe"]["columns"]["tof"]

         if self._dataframe is not None:
             logger.info("Adding time-of-flight column in nanoseconds to dataframe.")
@@ -1652,7 +1652,7 @@ def align_dld_sectors(
                 Defaults to False.
             **kwds: additional arguments are passed to ``EnergyCalibrator.align_dld_sectors()``.
         """
-        tof_column = self._config["dataframe"]["tof_column"]
+        tof_column = self._config["dataframe"]["columns"]["tof"]

         if self._dataframe is not None:
             logger.info("Aligning 8s sectors of dataframe")
@@ -1706,7 +1706,7 @@ def calibrate_delay_axis(
                 Defaults to False.
             **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
         """
-        adc_column = self._config["dataframe"]["adc_column"]
+        adc_column = self._config["dataframe"]["columns"]["adc"]
         if adc_column not in self._dataframe.columns:
             raise ValueError(f"ADC column {adc_column} not found in dataframe, cannot calibrate!")

@@ -1822,7 +1822,7 @@ def add_delay_offset(
         Raises:
             ValueError: If the delay column is not in the dataframe.
         """
-        delay_column = self._config["dataframe"]["delay_column"]
+        delay_column = self._config["dataframe"]["columns"]["delay"]
         if delay_column not in self._dataframe.columns:
             raise ValueError(f"Delay column {delay_column} not found in dataframe! ")

@@ -1945,7 +1945,7 @@ def add_jitter(
         cols = self._config["dataframe"]["jitter_cols"]
         for loc, col in enumerate(cols):
             if col.startswith("@"):
-                cols[loc] = self._config["dataframe"].get(col.strip("@"))
+                cols[loc] = self._config["dataframe"]["columns"].get(col.strip("@"))

         if amps is None:
             amps = self._config["dataframe"]["jitter_amps"]
@@ -2005,7 +2005,7 @@ def add_time_stamped_data(
         """
         time_stamp_column = kwds.pop(
             "time_stamp_column",
-            self._config["dataframe"].get("time_stamp_alias", ""),
+            self._config["dataframe"]["columns"].get("timestamp", ""),
         )

         if time_stamps is None and data is None:
@@ -2080,7 +2080,7 @@ def pre_binning(
         axes = self._config["momentum"]["axes"]
         for loc, axis in enumerate(axes):
             if axis.startswith("@"):
-                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
+                axes[loc] = self._config["dataframe"]["columns"].get(axis.strip("@"))

         if bins is None:
             bins = self._config["momentum"]["bins"]
@@ -2314,14 +2314,14 @@ def get_normalization_histogram(
                     self._dataframe.partitions[df_partitions],
                     axis,
                     self._binned.coords[axis].values,
-                    self._config["dataframe"]["time_stamp_alias"],
+                    self._config["dataframe"]["columns"]["timestamp"],
                 )
             else:
                 self._normalization_histogram = normalization_histogram_from_timestamps(
                     self._dataframe,
                     axis,
                     self._binned.coords[axis].values,
-                    self._config["dataframe"]["time_stamp_alias"],
+                    self._config["dataframe"]["columns"]["timestamp"],
                 )
         else:
             if df_partitions is not None:
@@ -2387,13 +2387,13 @@ def view_event_histogram(
         axes = list(axes)
         for loc, axis in enumerate(axes):
             if axis.startswith("@"):
-                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
+                axes[loc] = self._config["dataframe"]["columns"].get(axis.strip("@"))
         if ranges is None:
             ranges = list(self._config["histogram"]["ranges"])
             for loc, axis in enumerate(axes):
-                if axis == self._config["dataframe"]["tof_column"]:
+                if axis == self._config["dataframe"]["columns"]["tof"]:
                     ranges[loc] = np.asarray(ranges[loc]) / self._config["dataframe"]["tof_binning"]
-                elif axis == self._config["dataframe"]["adc_column"]:
+                elif axis == self._config["dataframe"]["columns"]["adc"]:
                     ranges[loc] = np.asarray(ranges[loc]) / self._config["dataframe"]["adc_binning"]

         input_types = map(type, [axes, bins, ranges])
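Several of the methods touched above (`add_jitter`, `pre_binning`, `view_event_histogram`) share the same `@`-alias pattern: an axis or column name starting with `@` is resolved against the config before use, and after this commit that lookup goes through the nested `columns` mapping. A minimal standalone sketch of that resolution step, with invented config contents and axis names:

# Minimal sketch of the "@"-alias resolution shown in the diff above.
# The config values and the example axes list are illustrative assumptions.
config = {"dataframe": {"columns": {"tof": "dldTimeSteps", "adc": "ADC"}}}

axes = ["@tof", "delayStage"]
for loc, axis in enumerate(axes):
    if axis.startswith("@"):
        # Strip the "@" and look the alias up in the nested columns mapping.
        axes[loc] = config["dataframe"]["columns"].get(axis.strip("@"))

print(axes)  # ['dldTimeSteps', 'delayStage']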
