Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FFT tool interface #27

Merged
merged 22 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 208 additions & 0 deletions specsanalyzer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,211 @@ def cropit(val): # pylint: disable=unused-argument
plt.show()
if apply:
cropit("")

def fft_tool(
self,
raw_image: np.ndarray,
apply: bool = False,
**kwds,
):
"""FFT tool to play around with the peak parameters in the Fourier plane.
Built to filter out the meshgrid appearing in the raw data images.
Args:
raw_image: A single 2-D data set.
"""
matplotlib.use("module://ipympl.backend_nbagg")
try:
fft_tool_params = (
kwds["fft_tool_params"]
if "fft_tool_params" in kwds
else self._correction_matrix_dict["fft_tool_params"]
)
(amp, pos_x, pos_y, sig_x, sig_y) = (
fft_tool_params["amplitude"],
fft_tool_params["pos_x"],
fft_tool_params["pos_y"],
fft_tool_params["sigma_x"],
fft_tool_params["sigma_y"],
)
except KeyError:
(amp, pos_x, pos_y, sig_x, sig_y) = (0.95, 86, 116, 13, 22)

fft_filter_peaks = create_fft_params_dict(amp, pos_x, pos_y, sig_x, sig_y)
try:
img = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="fft")
fft_filtered = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="filtered_fft")

mask = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="mask")

filtered = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="filtered")
except IndexError:
print("Load the scan first!")
raise

fig = plt.figure()
ax = fig.add_subplot(3, 2, 1)
im_fft = ax.imshow(np.abs(img).T, origin="lower", aspect="auto")
fig.colorbar(im_fft)

ax.set_title("FFT")
cont = ax.contour(mask.T)

# Plot raw image
ax2 = fig.add_subplot(3, 2, 2)
fft_filt = ax2.imshow(np.abs(fft_filtered).T, origin="lower", aspect="auto")
ax2.set_title("Filtered FFT")
fig.colorbar(fft_filt)

# Plot fft filtered image
ax3 = fig.add_subplot(2, 2, 3)
filt = ax3.imshow(filtered.T, origin="lower", aspect="auto")
ax3.set_title("Filtered Image")
fig.colorbar(filt)

ax4 = fig.add_subplot(3, 2, 4)
(edc,) = ax4.plot(np.sum(filtered, 0), label="EDC")
ax4.legend()

ax5 = fig.add_subplot(3, 2, 6)
(mdc,) = ax5.plot(np.sum(filtered, 1), label="MDC")
ax5.legend()
# plt.tight_layout()

posx_slider = ipw.FloatSlider(
description="pos_x",
value=pos_x,
min=0,
max=128,
step=1,
)
posy_slider = ipw.FloatSlider(
description="pos_y",
value=pos_y,
min=0,
max=150,
step=1,
)
sigx_slider = ipw.FloatSlider(
description="sig_x",
value=sig_x,
min=0,
max=200,
step=1,
)
sigy_slider = ipw.FloatSlider(
description="sig_y",
value=sig_y,
min=0,
max=200,
step=1,
)
amp_slider = ipw.FloatSlider(
description="Amplitude",
value=amp,
min=0,
max=1,
step=0.01,
)
clim_slider = ipw.FloatLogSlider(
description="colorbar limits",
value=1e4,
base=10,
min=-1,
max=int(np.log10(np.abs(img).max())) + 1,
)

def update(v_vals, pos_x, pos_y, sig_x, sig_y, amp):
fft_filter_peaks = create_fft_params_dict(amp, pos_x, pos_y, sig_x, sig_y)
msk = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="mask")
filtered_new = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="filtered")

fft_filtered_new = fourier_filter_2d(
raw_image,
peaks=fft_filter_peaks,
ret="filtered_fft",
)

im_fft.set_clim(vmax=v_vals)
fft_filt.set_clim(vmax=v_vals)

filt.set_data(filtered_new.T)
fft_filt.set_data(np.abs(fft_filtered_new.T))

nonlocal cont
for i in range(len(cont.collections)):
cont.collections[i].remove()
cont = ax.contour(msk.T)

edc.set_ydata(np.sum(filtered_new, 0))
mdc.set_ydata(np.sum(filtered_new, 1))

fig.canvas.draw_idle()

ipw.interact(
update,
amp=amp_slider,
pos_x=posx_slider,
pos_y=posy_slider,
sig_x=sigx_slider,
sig_y=sigy_slider,
v_vals=clim_slider,
)

def apply_fft(apply: bool): # pylint: disable=unused-argument
amp = amp_slider.value
pos_x = posx_slider.value
pos_y = posy_slider.value
sig_x = sigx_slider.value
sig_y = sigy_slider.value
self._correction_matrix_dict["fft_tool_params"] = {
"amplitude": amp,
"pos_x": pos_x,
"pos_y": pos_y,
"sigma_x": sig_x,
"sigma_y": sig_y,
}
self.config["fft_filter_peaks"] = create_fft_params_dict(
amp,
pos_x,
pos_y,
sig_x,
sig_y,
)
amp_slider.close()
posx_slider.close()
posy_slider.close()
sigx_slider.close()
sigy_slider.close()
clim_slider.close()
apply_button.close()

apply_button = ipw.Button(description="Apply")
display(apply_button)
apply_button.on_click(apply_fft)
plt.show()
if apply:
apply_fft(True)


def create_fft_params_dict(amp, posx, posy, sigx, sigy):
"""Function to create fft filter peaks dict using the
provided Gaussian peak parameters. The peaks are defined
relative to each other such that they are periodically
aranged in a 256 x 150 Fourier space.
Args:
amp: Gaussian peak amplitude
posx: x-position
posy: y-position
sigx: FWHM in x-axis
sigy: FWHM in y-axis
"""

fft_filter_peaks = [
{"amplitude": amp, "pos_x": -posx, "pos_y": 0, "sigma_x": sigx, "sigma_y": sigy},
{"amplitude": amp, "pos_x": posx, "pos_y": 0, "sigma_x": sigx, "sigma_y": sigy},
{"amplitude": amp, "pos_x": 0, "pos_y": posy, "sigma_x": sigx, "sigma_y": sigy},
{"amplitude": amp, "pos_x": -posx, "pos_y": posy, "sigma_x": sigx, "sigma_y": sigy},
{"amplitude": amp, "pos_x": posx, "pos_y": posy, "sigma_x": sigx, "sigma_y": sigy},
]

return fft_filter_peaks
6 changes: 4 additions & 2 deletions specsanalyzer/img_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def fourier_filter_2d(

# Do Fourier Transform of the (real-valued) image
image_fft = np.fft.rfft2(image)
# shift fft axis to have 0 in the center
image_fft = np.fft.fftshift(image_fft, axes=0)
mask = np.ones(image_fft.shape)
xgrid, ygrid = np.meshgrid(
range(image_fft.shape[0]),
Expand All @@ -73,7 +75,7 @@ def fourier_filter_2d(
mask -= peak["amplitude"] * gauss2d(
xgrid,
ygrid,
peak["pos_x"],
image_fft.shape[0] / 2 + peak["pos_x"],
peak["pos_y"],
peak["sigma_x"],
peak["sigma_y"],
Expand All @@ -85,7 +87,7 @@ def fourier_filter_2d(
) from exc

# apply mask to the FFT, and transform back
filtered = np.fft.irfft2(image_fft * mask)
filtered = np.fft.irfft2(np.fft.ifftshift(image_fft * mask, axes=0))
# strip negative values
filtered = filtered.clip(min=0)
if ret == "filtered":
Expand Down
13 changes: 4 additions & 9 deletions specsscan/config/example_config_FHI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ spa_params:
# sigma_x/sigma_y: the peak width (standard deviation) along each direction
fft_filter_peaks:
- amplitude: 1
pos_x: 79
pos_x: -49
pos_y: 0
sigma_x: 8
sigma_y: 8
- amplitude: 1
pos_x: 176
pos_x: 48
pos_y: 0
sigma_x: 8
sigma_y: 8
Expand All @@ -101,17 +101,12 @@ spa_params:
sigma_x: 5
sigma_y: 8
- amplitude: 1
pos_x: 78
pos_x: -50
pos_y: 109
sigma_x: 5
sigma_y: 5
- amplitude: 1
pos_x: 175
pos_x: 47
pos_y: 108
sigma_x: 5
sigma_y: 5
- amplitude: 1
pos_x: 254
pos_y: 109
sigma_x: 5
sigma_y: 8
21 changes: 21 additions & 0 deletions specsscan/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,27 @@ def crop_tool(self, scan: int = None, path: Path | str = "", **kwds):
**kwds,
)

def fft_tool(self, scan: int = None, path: Path | str = "", **kwds):
matplotlib.use("module://ipympl.backend_nbagg")
if scan is not None:
scan_path = get_scan_path(path, scan, self._config["data_path"])

data = load_images(
scan_path=scan_path,
tqdm_enable_nested=self._config["enable_nested_progress_bar"],
)
image = data[0]
else:
try:
image = self.metadata["loader"]["raw_data"][0]
except KeyError as exc:
raise ValueError("No image loaded, load image first!") from exc

self.spa.fft_tool(
image,
**kwds,
)

def check_scan(
self,
scan: int,
Expand Down
2 changes: 1 addition & 1 deletion tests/data
Submodule data updated from 9a1110 to b84798
9 changes: 7 additions & 2 deletions tests/test_specsscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,19 @@ def test_conversion_3d():
path=test_dir,
iterations=[0],
)
assert res_xarray.sum(axis=(0, 1, 2)) != res_xarray2.sum(axis=(0, 1, 2))
np.testing.assert_raises(
AssertionError,
np.testing.assert_allclose,
res_xarray.values,
res_xarray2.values,
)

res_xarray2 = sps.load_scan(
scan=4450,
path=test_dir,
iterations=np.s_[0:2],
)
assert res_xarray.sum(axis=(0, 1, 2)) == res_xarray2.sum(axis=(0, 1, 2))
np.testing.assert_allclose(res_xarray, res_xarray2)

with pytest.raises(IndexError):
sps.load_scan(
Expand Down
Loading