Skip to content

Commit

Permalink
Merge pull request #20 from OpenCOMPES/conversion_parameters
Browse files Browse the repository at this point in the history
Conversion parameters
  • Loading branch information
rettigl authored Mar 12, 2024
2 parents a404259 + 9438445 commit 0671200
Show file tree
Hide file tree
Showing 11 changed files with 554 additions and 297 deletions.
255 changes: 104 additions & 151 deletions specsanalyzer/convert.py

Large diffs are not rendered by default.

195 changes: 103 additions & 92 deletions specsanalyzer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
from IPython.display import display

from specsanalyzer import io
from specsanalyzer.config import complete_dictionary
from specsanalyzer.config import parse_config
from specsanalyzer.convert import calculate_matrix_correction
from specsanalyzer.convert import get_damatrix_fromcalib2d
from specsanalyzer.convert import physical_unit_data
from specsanalyzer.img_tools import crop_xarray
from specsanalyzer.img_tools import fourier_filter_2d
Expand Down Expand Up @@ -54,12 +56,10 @@ def __init__(
self._data_array = None
self.print_msg = True
try:
self._config["calib2d_dict"] = io.parse_calib2d_to_dict(
self._config["calib2d_file"],
)
self._calib2d = io.parse_calib2d_to_dict(self._config["calib2d_file"])

except FileNotFoundError: # default location relative to package directory
self._config["calib2d_dict"] = io.parse_calib2d_to_dict(
self._calib2d = io.parse_calib2d_to_dict(
os.path.join(package_dir, self._config["calib2d_file"]),
)

Expand All @@ -76,17 +76,17 @@ def __repr__(self):
return pretty_str if pretty_str is not None else ""

@property
def config(self):
def config(self) -> dict:
"""Get config"""
return self._config

@config.setter
def config(self, config: dict | str):
"""Set config"""
self._config = parse_config(config)
@property
def calib2d(self) -> dict:
"""Get calib2d dict"""
return self._calib2d

@property
def correction_matrix_dict(self):
def correction_matrix_dict(self) -> dict:
"""Get correction_matrix_dict"""
return self._correction_matrix_dict

Expand All @@ -97,6 +97,7 @@ def convert_image(
kinetic_energy: float,
pass_energy: float,
work_function: float,
conversion_parameters: dict = None,
**kwds,
) -> xr.DataArray:
"""Converts an imagin in physical unit data, angle vs energy
Expand All @@ -108,44 +109,85 @@ def convert_image(
kinetic_energy (float): set analyser kinetic energy
pass_energy (float): set analyser pass energy
work_function (float): set analyser work function
conversion_parameters (dict, optional): dictionary of conversion parameters,
overwriting determination from calib2d file. Defaults to None.
Returns:
xr.DataArray: xarray containg the corrected data and kinetic and angle axis
"""
if conversion_parameters is None:
conversion_parameters = {}
else:
conversion_parameters = conversion_parameters.copy()

apply_fft_filter = kwds.pop("apply_fft_filter", self._config.get("apply_fft_filter", False))
binning = kwds.pop("binning", self._config.get("binning", 1))
if "apply_fft_filter" not in conversion_parameters.keys():
conversion_parameters["apply_fft_filter"] = kwds.pop(
"apply_fft_filter",
self._config.get("apply_fft_filter", False),
)
if "binning" not in conversion_parameters.keys():
conversion_parameters["binning"] = kwds.pop("binning", self._config.get("binning", 1))
if "rotation_angle" not in conversion_parameters.keys():
conversion_parameters["rotation_angle"] = kwds.pop(
"rotation_angle",
self._config.get("rotation_angle", 0),
)

if apply_fft_filter:
if conversion_parameters["apply_fft_filter"]:
try:
fft_filter_peaks = kwds.pop("fft_filter_peaks", self._config["fft_filter_peaks"])
img = fourier_filter_2d(raw_img, fft_filter_peaks)
if "fft_filter_peaks" not in conversion_parameters.keys():
conversion_parameters["fft_filter_peaks"] = kwds.pop(
"fft_filter_peaks",
self._config["fft_filter_peaks"],
)
img = fourier_filter_2d(raw_img, conversion_parameters["fft_filter_peaks"])
except KeyError:
img = raw_img
conversion_parameters["apply_fft_filter"] = False
else:
img = raw_img

rotation_angle = kwds.pop("rotation_angle", self._config.get("rotation_angle", 0))

if rotation_angle:
img_rotated = imutils.rotate(img, angle=rotation_angle)
if conversion_parameters["rotation_angle"]:
img_rotated = imutils.rotate(img, angle=conversion_parameters["rotation_angle"])
img = img_rotated

# look for the lens mode in the dictionary
try:
supported_angle_modes = self._config["calib2d_dict"]["supported_angle_modes"]
supported_space_modes = self._config["calib2d_dict"]["supported_space_modes"]
# pylint: disable=duplicate-code
except KeyError as exc:
raise KeyError(
"The supported modes were not found in the calib2d dictionary",
) from exc

if lens_mode not in [*supported_angle_modes, *supported_space_modes]:
raise ValueError(
f"convert_image: unsupported lens mode: '{lens_mode}'",
if "lens_mode" not in conversion_parameters.keys():
conversion_parameters["lens_mode"] = lens_mode
conversion_parameters["kinetic_energy"] = kinetic_energy
conversion_parameters["pass_energy"] = pass_energy
conversion_parameters["work_function"] = work_function
# Determine conversion parameters from calib2d
(
conversion_parameters["a_inner"],
conversion_parameters["da_matrix"],
conversion_parameters["retardation_ratio"],
conversion_parameters["source"],
conversion_parameters["dims"],
) = get_damatrix_fromcalib2d(
lens_mode=lens_mode,
kinetic_energy=kinetic_energy,
pass_energy=pass_energy,
work_function=work_function,
calib2d_dict=self._calib2d,
)
conversion_parameters["e_shift"] = np.array(self._calib2d["eShift"])
conversion_parameters["de1"] = [self._calib2d["De1"]]
conversion_parameters["e_range"] = self._calib2d["eRange"]
conversion_parameters["a_range"] = self._calib2d[lens_mode]["default"]["aRange"]
conversion_parameters["pixel_size"] = (
self._config["pixel_size"] * self._config["binning"]
)
conversion_parameters["magnification"] = self._config["magnification"]
conversion_parameters["angle_offset_px"] = kwds.get(
"angle_offset_px",
self._config.get("angle_offset_px", 0),
)
conversion_parameters["energy_offset_px"] = kwds.get(
"energy_offset_px",
self._config.get("energy_offset_px", 0),
)

# do we need to calculate a new conversion matrix? Check correction matrix dict:
new_matrix = False
try:
old_db = self._correction_matrix_dict[lens_mode][kinetic_energy][pass_energy][
Expand All @@ -170,17 +212,23 @@ def convert_image(
e_correction,
jacobian_determinant,
) = calculate_matrix_correction(
lens_mode,
kinetic_energy,
pass_energy,
work_function,
binning,
self._config,
**kwds,
kinetic_energy=kinetic_energy,
pass_energy=pass_energy,
nx_pixels=img.shape[1],
ny_pixels=img.shape[0],
pixel_size=conversion_parameters["pixel_size"],
magnification=conversion_parameters["magnification"],
e_shift=conversion_parameters["e_shift"],
de1=conversion_parameters["de1"],
e_range=conversion_parameters["e_range"],
a_range=conversion_parameters["a_range"],
a_inner=conversion_parameters["a_inner"],
da_matrix=conversion_parameters["da_matrix"],
angle_offset_px=conversion_parameters["angle_offset_px"],
energy_offset_px=conversion_parameters["energy_offset_px"],
)

# save the config parameters for later use
# collect the info in a new nested dictionary
# save the config parameters for later use collect the info in a new nested dictionary
current_correction = {
lens_mode: {
kinetic_energy: {
Expand All @@ -198,16 +246,15 @@ def convert_image(
}

# add the new lens mode to the correction matrix dict
self._correction_matrix_dict = dict(
mergedicts(self._correction_matrix_dict, current_correction),
self._correction_matrix_dict = complete_dictionary(
self._correction_matrix_dict,
current_correction,
)

else:
old_matrix_check = True

# save a flag called old_matrix_check to determine if the current
# image was corrected using (True) or not using (False) the
# parameter in the class
# save a flag called old_matrix_check to determine if the current image was corrected using
# (True) or not using (False) the parameter in the class

self._correction_matrix_dict["old_matrix_check"] = old_matrix_check

Expand All @@ -218,20 +265,15 @@ def convert_image(
jacobian_determinant,
)

# TODO: annotate with metadata

if lens_mode in supported_angle_modes:
data_array = xr.DataArray(
data=conv_img,
coords={"Angle": angle_axis, "Ekin": ek_axis},
dims=["Angle", "Ekin"],
)
elif lens_mode in supported_space_modes:
data_array = xr.DataArray(
data=conv_img,
coords={"Position": angle_axis, "Ekin": ek_axis},
dims=["Position", "Ekin"],
)
data_array = xr.DataArray(
data=conv_img,
coords={
conversion_parameters["dims"][0]: angle_axis,
conversion_parameters["dims"][1]: ek_axis,
},
dims=conversion_parameters["dims"],
attrs={"conversion_parameters": conversion_parameters},
)

# Handle cropping based on parameters stored in correction dictionary
crop = kwds.pop("crop", self._config.get("crop", False))
Expand Down Expand Up @@ -523,34 +565,3 @@ def cropit(val): # pylint: disable=unused-argument
plt.show()
if apply:
cropit("")


def mergedicts(
dict1: dict,
dict2: dict,
) -> Generator[tuple[Any, Any], None, None]:
"""Merge two dictionaries, overwriting only existing values and retaining
previously present values
Args:
dict1 (dict): dictionary 1
dict2 (dict): dictionary 2
Yields:
dict: merged dictionary generator
"""
for k in set(dict1.keys()).union(dict2.keys()):
if k in dict1 and k in dict2:
if isinstance(dict1[k], dict) and isinstance(dict2[k], dict):
yield (k, dict(mergedicts(dict1[k], dict2[k])))
else:
# If one of the values is not a dict,
# you can't continue merging it.
# Value from second dict overrides one in first and we move on.
yield (k, dict2[k])
# Alternatively, replace this with exception
# raiser to alert you of value conflicts
elif k in dict1:
yield (k, dict1[k])
else:
yield (k, dict2[k])
5 changes: 2 additions & 3 deletions specsanalyzer/img_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@


def gauss2d(
# pylint: disable=invalid-name, too-many-arguments
x: float | np.ndarray,
y: float | np.ndarray,
mx: float,
Expand Down Expand Up @@ -81,8 +80,8 @@ def fourier_filter_2d(
)
except KeyError as exc:
raise KeyError(
f"The peaks input is supposed to be a list of dicts with the\
following structure: pos_x, pos_y, sigma_x, sigma_y, amplitude. The error was {exc}.",
f"The peaks input is supposed to be a list of dicts with the "
"following structure: pos_x, pos_y, sigma_x, sigma_y, amplitude.",
) from exc

# apply mask to the FFT, and transform back
Expand Down
Loading

0 comments on commit 0671200

Please sign in to comment.