Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pass kwd arguments to convert_image in the crop tool #34

Merged
merged 2 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 108 additions & 69 deletions specsanalyzer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ def crop_tool(
- ek_range_max
- ang_range_min
- ang_range_max

Other parameters are passed to ``convert_image()``.
"""
data_array = self.convert_image(
raw_img=raw_img,
Expand All @@ -390,6 +392,7 @@ def crop_tool(
pass_energy=pass_energy,
work_function=work_function,
crop=False,
**kwds,
)

matplotlib.use("module://ipympl.backend_nbagg")
Expand Down Expand Up @@ -581,31 +584,29 @@ def fft_tool(
apply (bool, optional): Option to directly apply the settings. Defaults to False.
**kwds: Keyword arguments:

- fft_tool_params (dict): Dictionary of parameters for fft_tool, containing keys
`amplitude`: Normalized amplitude of subtraction
`pos_x`: horzontal spatial frequency of th mesh
`pos_y`: vertical spatial frequency of the mesh
`sigma_x`: horizontal frequency width
`sigma_y`: vertical frequency width
- `amplitude`: Normalized amplitude of subtraction
- `pos_x`: horzontal spatial frequency of th mesh
- `pos_y`: vertical spatial frequency of the mesh
- `sigma_x`: horizontal frequency width
- `sigma_y`: vertical frequency width
"""
matplotlib.use("module://ipympl.backend_nbagg")
try:
fft_tool_params = (
kwds["fft_tool_params"]
if "fft_tool_params" in kwds
else self._correction_matrix_dict["fft_tool_params"]
)
(amp, pos_x, pos_y, sig_x, sig_y) = (
fft_tool_params["amplitude"],
fft_tool_params["pos_x"],
fft_tool_params["pos_y"],
fft_tool_params["sigma_x"],
fft_tool_params["sigma_y"],
)
except KeyError:
(amp, pos_x, pos_y, sig_x, sig_y) = (0.95, 86, 116, 13, 22)
stored_parameters = self._correction_matrix_dict.get("fft_tool_params", {})
if not stored_parameters:
stored_parameters = {
"amplitude": 0.95,
"pos_x": 86,
"pos_y": 116,
"sigma_x": 13,
"sigma_y": 22,
}
amplitude = kwds.get("amplitude", stored_parameters["amplitude"])
pos_x = kwds.get("pos_x", stored_parameters["pos_x"])
pos_y = kwds.get("pos_y", stored_parameters["pos_y"])
sigma_x = kwds.get("sigma_x", stored_parameters["sigma_x"])
sigma_y = kwds.get("sigma_y", stored_parameters["sigma_y"])

fft_filter_peaks = create_fft_params(amp, pos_x, pos_y, sig_x, sig_y)
fft_filter_peaks = create_fft_params(amplitude, pos_x, pos_y, sigma_x, sigma_y)
try:
img = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="fft")
fft_filtered = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="filtered_fft")
Expand Down Expand Up @@ -646,37 +647,37 @@ def fft_tool(
ax5.legend()
# plt.tight_layout()

posx_slider = ipw.FloatSlider(
pos_x_slider = ipw.FloatSlider(
description="pos_x",
value=pos_x,
min=0,
max=128,
step=1,
)
posy_slider = ipw.FloatSlider(
pos_y_slider = ipw.FloatSlider(
description="pos_y",
value=pos_y,
min=0,
max=150,
step=1,
)
sigx_slider = ipw.FloatSlider(
sigma_x_slider = ipw.FloatSlider(
description="sig_x",
value=sig_x,
value=sigma_x,
min=0,
max=50,
step=1,
)
sigy_slider = ipw.FloatSlider(
sigma_y_slider = ipw.FloatSlider(
description="sig_y",
value=sig_y,
value=sigma_y,
min=0,
max=50,
step=1,
)
amp_slider = ipw.FloatSlider(
amplitude_slider = ipw.FloatSlider(
description="Amplitude",
value=amp,
value=amplitude,
min=0,
max=1,
step=0.01,
Expand All @@ -689,8 +690,8 @@ def fft_tool(
max=int(np.log10(np.abs(img).max())) + 1,
)

def update(v_vals, pos_x, pos_y, sig_x, sig_y, amp):
fft_filter_peaks = create_fft_params(amp, pos_x, pos_y, sig_x, sig_y)
def update(v_vals, pos_x, pos_y, sigma_x, sigma_y, amplitude):
fft_filter_peaks = create_fft_params(amplitude, pos_x, pos_y, sigma_x, sigma_y)
msk = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="mask")
filtered_new = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="filtered")

Expand Down Expand Up @@ -718,39 +719,39 @@ def update(v_vals, pos_x, pos_y, sig_x, sig_y, amp):

ipw.interact(
update,
amp=amp_slider,
pos_x=posx_slider,
pos_y=posy_slider,
sig_x=sigx_slider,
sig_y=sigy_slider,
amplitude=amplitude_slider,
pos_x=pos_x_slider,
pos_y=pos_y_slider,
sigma_x=sigma_x_slider,
sigma_y=sigma_y_slider,
v_vals=clim_slider,
)

def apply_fft(apply: bool): # pylint: disable=unused-argument
amp = amp_slider.value
pos_x = posx_slider.value
pos_y = posy_slider.value
sig_x = sigx_slider.value
sig_y = sigy_slider.value
amplitude = amplitude_slider.value
pos_x = pos_x_slider.value
pos_y = pos_y_slider.value
sigma_x = sigma_x_slider.value
sigma_y = sigma_y_slider.value
self._correction_matrix_dict["fft_tool_params"] = {
"amplitude": amp,
"amplitude": amplitude,
"pos_x": pos_x,
"pos_y": pos_y,
"sigma_x": sig_x,
"sigma_y": sig_y,
"sigma_x": sigma_x,
"sigma_y": sigma_y,
}
self.config["fft_filter_peaks"] = create_fft_params(
amp,
amplitude,
pos_x,
pos_y,
sig_x,
sig_y,
sigma_x,
sigma_y,
)
amp_slider.close()
posx_slider.close()
posy_slider.close()
sigx_slider.close()
sigy_slider.close()
amplitude_slider.close()
pos_x_slider.close()
pos_y_slider.close()
sigma_x_slider.close()
sigma_y_slider.close()
clim_slider.close()
apply_button.close()

Expand All @@ -762,25 +763,63 @@ def apply_fft(apply: bool): # pylint: disable=unused-argument
apply_fft(True)


def create_fft_params(amp, pos_x, pos_y, sig_x, sig_y) -> list[dict]:
"""Function to create fft filter peaks list using the
provided Gaussian peak parameters. The peaks are defined
relative to each other such that they are periodically
aranged in a 256 x 150 Fourier space.
def create_fft_params(
amplitude: float,
pos_x: float,
pos_y: float,
sigma_x: float,
sigma_y: float,
) -> list[dict]:
"""Function to create fft filter peaks list using the provided Gaussian peak parameters.
The peaks are placed at +-x, y=0, and +-x, y=y, with width corresponding to the sigma
values.

Args:
amp: Gaussian peak amplitude
pos_x: x-position
pos_y: y-position
sig_x: FWHM in x-axis
sig_y: FWHM in y-axis
amplitude (float): Gaussian peak amplitude
pos_x (float): horizontal spatial frequency
pos_y (float): vertical spatial frequency
sigma_x (float): horizontal width
sigma_y (float): vertical width

Returns:
list[dict]: A list of the defined filter parameters
"""

fft_filter_peaks = [
{"amplitude": amp, "pos_x": -pos_x, "pos_y": 0, "sigma_x": sig_x, "sigma_y": sig_y},
{"amplitude": amp, "pos_x": pos_x, "pos_y": 0, "sigma_x": sig_x, "sigma_y": sig_y},
{"amplitude": amp, "pos_x": 0, "pos_y": pos_y, "sigma_x": sig_x, "sigma_y": sig_y},
{"amplitude": amp, "pos_x": -pos_x, "pos_y": pos_y, "sigma_x": sig_x, "sigma_y": sig_y},
{"amplitude": amp, "pos_x": pos_x, "pos_y": pos_y, "sigma_x": sig_x, "sigma_y": sig_y},
{
"amplitude": amplitude,
"pos_x": -pos_x,
"pos_y": 0,
"sigma_x": sigma_x,
"sigma_y": sigma_y,
},
{
"amplitude": amplitude,
"pos_x": pos_x,
"pos_y": 0,
"sigma_x": sigma_x,
"sigma_y": sigma_y,
},
{
"amplitude": amplitude,
"pos_x": 0,
"pos_y": pos_y,
"sigma_x": sigma_x,
"sigma_y": sigma_y,
},
{
"amplitude": amplitude,
"pos_x": -pos_x,
"pos_y": pos_y,
"sigma_x": sigma_x,
"sigma_y": sigma_y,
},
{
"amplitude": amplitude,
"pos_x": pos_x,
"pos_y": pos_y,
"sigma_x": sigma_x,
"sigma_y": sigma_y,
},
]

return fft_filter_peaks
16 changes: 16 additions & 0 deletions specsscan/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,22 @@ def crop_tool(self, scan: int = None, path: Path | str = "", **kwds):
)

def fft_tool(self, scan: int = None, path: Path | str = "", **kwds):
"""FFT tool to play around with the peak parameters in the Fourier plane. Built to filter
out the meshgrid appearing in the raw data images. The optimized parameters are stored in
the class config dict under fft_filter_peaks.

Args:
scan (int, optional): Scan number to load. Defaults to the previously loaded scan.
path (Path | str): Path from where to load the data. Defaults to config value.
**kwds: Keyword arguments passed to ``SpecsAnalyzer.fft_tool()``:

- `apply`: Option to directly apply the settings.
- `amplitude`: Normalized amplitude of subtraction
- `pos_x`: horzontal spatial frequency of th mesh
- `pos_y`: vertical spatial frequency of the mesh
- `sigma_x`: horizontal frequency width
- `sigma_y`: vertical frequency width
"""
matplotlib.use("module://ipympl.backend_nbagg")
if scan is not None:
scan_path = get_scan_path(path, scan, self._config["data_path"])
Expand Down
8 changes: 6 additions & 2 deletions tests/test_specsscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

package_dir = os.path.dirname(specsscan.__file__)
test_dir = package_dir + "/../tests/data/"
fft_filter_peaks = create_fft_params(amp=1, pos_x=82, pos_y=116, sig_x=15, sig_y=23)
fft_filter_peaks = create_fft_params(amplitude=1, pos_x=82, pos_y=116, sigma_x=15, sigma_y=23)


def test_version():
Expand Down Expand Up @@ -289,7 +289,11 @@ def test_fft_tool():
np.testing.assert_almost_equal(res_xarray.data.sum(), 62197237155.50347, decimal=3)

sps.fft_tool(
fft_tool_params={"amplitude": 1, "pos_x": 82, "pos_y": 116, "sigma_x": 15, "sigma_y": 23},
amplitude=1,
pos_x=82,
pos_y=116,
sigma_x=15,
sigma_y=23,
apply=True,
)
assert sps.config["spa_params"]["fft_filter_peaks"] == fft_filter_peaks
Expand Down
12 changes: 5 additions & 7 deletions tutorial/2_specsscan_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,11 @@
"outputs": [],
"source": [
"sps.fft_tool(\n",
" fft_tool_params={\n",
" \"amplitude\": 1,\n",
" \"pos_x\": 82,\n",
" \"pos_y\": 116,\n",
" \"sigma_x\": 15,\n",
" \"sigma_y\": 23\n",
" },\n",
" amplitude=1,\n",
" pos_x=82,\n",
" pos_y=116,\n",
" sigma_x=15,\n",
" sigma_y=23,\n",
" apply=True # Use apply=False for interactive mode\n",
")"
]
Expand Down
Loading