diff --git a/specsanalyzer/core.py b/specsanalyzer/core.py index 940aaf5..9acec64 100755 --- a/specsanalyzer/core.py +++ b/specsanalyzer/core.py @@ -565,3 +565,222 @@ def cropit(val): # pylint: disable=unused-argument plt.show() if apply: cropit("") + + def fft_tool( + self, + raw_image: np.ndarray, + apply: bool = False, + **kwds, + ): + """FFT tool to play around with the peak parameters in the Fourier plane. Built to filter + out the meshgrid appearing in the raw data images. The optimized parameters are stored in + the class config dict under fft_filter_peaks. + + Args: + raw_image (np.ndarray): The source image + apply (bool, optional): Option to directly apply the settings. Defaults to False. + **kwds: Keyword arguments: + + - fft_tool_params (dict): Dictionary of parameters for fft_tool, containing keys + `amplitude`: Normalized amplitude of subtraction + `pos_x`: horzontal spatial frequency of th mesh + `pos_y`: vertical spatial frequency of the mesh + `sigma_x`: horizontal frequency width + `sigma_y`: vertical frequency width + """ + matplotlib.use("module://ipympl.backend_nbagg") + try: + fft_tool_params = ( + kwds["fft_tool_params"] + if "fft_tool_params" in kwds + else self._correction_matrix_dict["fft_tool_params"] + ) + (amp, pos_x, pos_y, sig_x, sig_y) = ( + fft_tool_params["amplitude"], + fft_tool_params["pos_x"], + fft_tool_params["pos_y"], + fft_tool_params["sigma_x"], + fft_tool_params["sigma_y"], + ) + except KeyError: + (amp, pos_x, pos_y, sig_x, sig_y) = (0.95, 86, 116, 13, 22) + + fft_filter_peaks = create_fft_params(amp, pos_x, pos_y, sig_x, sig_y) + try: + img = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="fft") + fft_filtered = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="filtered_fft") + + mask = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="mask") + + filtered = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="filtered") + except IndexError: + print("Load the scan first!") + raise + + fig = plt.figure() + ax = fig.add_subplot(3, 2, 1) + im_fft = ax.imshow(np.abs(img).T, origin="lower", aspect="auto") + fig.colorbar(im_fft) + + ax.set_title("FFT") + cont = ax.contour(mask.T) + + # Plot raw image + ax2 = fig.add_subplot(3, 2, 2) + fft_filt = ax2.imshow(np.abs(fft_filtered).T, origin="lower", aspect="auto") + ax2.set_title("Filtered FFT") + fig.colorbar(fft_filt) + + # Plot fft filtered image + ax3 = fig.add_subplot(2, 2, 3) + filt = ax3.imshow(filtered.T, origin="lower", aspect="auto") + ax3.set_title("Filtered Image") + fig.colorbar(filt) + + ax4 = fig.add_subplot(3, 2, 4) + (edc,) = ax4.plot(np.sum(filtered, 0), label="EDC") + ax4.legend() + + ax5 = fig.add_subplot(3, 2, 6) + (mdc,) = ax5.plot(np.sum(filtered, 1), label="MDC") + ax5.legend() + # plt.tight_layout() + + posx_slider = ipw.FloatSlider( + description="pos_x", + value=pos_x, + min=0, + max=128, + step=1, + ) + posy_slider = ipw.FloatSlider( + description="pos_y", + value=pos_y, + min=0, + max=150, + step=1, + ) + sigx_slider = ipw.FloatSlider( + description="sig_x", + value=sig_x, + min=0, + max=50, + step=1, + ) + sigy_slider = ipw.FloatSlider( + description="sig_y", + value=sig_y, + min=0, + max=50, + step=1, + ) + amp_slider = ipw.FloatSlider( + description="Amplitude", + value=amp, + min=0, + max=1, + step=0.01, + ) + clim_slider = ipw.FloatLogSlider( + description="colorbar limits", + value=int(np.abs(img).max() / 500), + base=10, + min=-1, + max=int(np.log10(np.abs(img).max())) + 1, + ) + + def update(v_vals, pos_x, pos_y, sig_x, sig_y, amp): + fft_filter_peaks = create_fft_params(amp, pos_x, pos_y, sig_x, sig_y) + msk = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="mask") + filtered_new = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="filtered") + + fft_filtered_new = fourier_filter_2d( + raw_image, + peaks=fft_filter_peaks, + ret="filtered_fft", + ) + + im_fft.set_clim(vmax=v_vals) + fft_filt.set_clim(vmax=v_vals) + + filt.set_data(filtered_new.T) + fft_filt.set_data(np.abs(fft_filtered_new.T)) + + nonlocal cont + for i in range(len(cont.collections)): + cont.collections[i].remove() + cont = ax.contour(msk.T) + + edc.set_ydata(np.sum(filtered_new, 0)) + mdc.set_ydata(np.sum(filtered_new, 1)) + + fig.canvas.draw_idle() + + ipw.interact( + update, + amp=amp_slider, + pos_x=posx_slider, + pos_y=posy_slider, + sig_x=sigx_slider, + sig_y=sigy_slider, + v_vals=clim_slider, + ) + + def apply_fft(apply: bool): # pylint: disable=unused-argument + amp = amp_slider.value + pos_x = posx_slider.value + pos_y = posy_slider.value + sig_x = sigx_slider.value + sig_y = sigy_slider.value + self._correction_matrix_dict["fft_tool_params"] = { + "amplitude": amp, + "pos_x": pos_x, + "pos_y": pos_y, + "sigma_x": sig_x, + "sigma_y": sig_y, + } + self.config["fft_filter_peaks"] = create_fft_params( + amp, + pos_x, + pos_y, + sig_x, + sig_y, + ) + amp_slider.close() + posx_slider.close() + posy_slider.close() + sigx_slider.close() + sigy_slider.close() + clim_slider.close() + apply_button.close() + + apply_button = ipw.Button(description="Apply") + display(apply_button) + apply_button.on_click(apply_fft) + plt.show() + if apply: + apply_fft(True) + + +def create_fft_params(amp, pos_x, pos_y, sig_x, sig_y) -> list[dict]: + """Function to create fft filter peaks list using the + provided Gaussian peak parameters. The peaks are defined + relative to each other such that they are periodically + aranged in a 256 x 150 Fourier space. + Args: + amp: Gaussian peak amplitude + pos_x: x-position + pos_y: y-position + sig_x: FWHM in x-axis + sig_y: FWHM in y-axis + """ + + fft_filter_peaks = [ + {"amplitude": amp, "pos_x": -pos_x, "pos_y": 0, "sigma_x": sig_x, "sigma_y": sig_y}, + {"amplitude": amp, "pos_x": pos_x, "pos_y": 0, "sigma_x": sig_x, "sigma_y": sig_y}, + {"amplitude": amp, "pos_x": 0, "pos_y": pos_y, "sigma_x": sig_x, "sigma_y": sig_y}, + {"amplitude": amp, "pos_x": -pos_x, "pos_y": pos_y, "sigma_x": sig_x, "sigma_y": sig_y}, + {"amplitude": amp, "pos_x": pos_x, "pos_y": pos_y, "sigma_x": sig_x, "sigma_y": sig_y}, + ] + + return fft_filter_peaks diff --git a/specsanalyzer/img_tools.py b/specsanalyzer/img_tools.py index d8b9a40..2a940c2 100755 --- a/specsanalyzer/img_tools.py +++ b/specsanalyzer/img_tools.py @@ -61,6 +61,8 @@ def fourier_filter_2d( # Do Fourier Transform of the (real-valued) image image_fft = np.fft.rfft2(image) + # shift fft axis to have 0 in the center + image_fft = np.fft.fftshift(image_fft, axes=0) mask = np.ones(image_fft.shape) xgrid, ygrid = np.meshgrid( range(image_fft.shape[0]), @@ -73,7 +75,7 @@ def fourier_filter_2d( mask -= peak["amplitude"] * gauss2d( xgrid, ygrid, - peak["pos_x"], + image_fft.shape[0] / 2 + peak["pos_x"], peak["pos_y"], peak["sigma_x"], peak["sigma_y"], @@ -85,7 +87,7 @@ def fourier_filter_2d( ) from exc # apply mask to the FFT, and transform back - filtered = np.fft.irfft2(image_fft * mask) + filtered = np.fft.irfft2(np.fft.ifftshift(image_fft * mask, axes=0)) # strip negative values filtered = filtered.clip(min=0) if ret == "filtered": diff --git a/specsscan/config/example_config_FHI.yaml b/specsscan/config/example_config_FHI.yaml index 12405c6..bcd7398 100644 --- a/specsscan/config/example_config_FHI.yaml +++ b/specsscan/config/example_config_FHI.yaml @@ -91,7 +91,7 @@ spa_params: sigma_x: 8 sigma_y: 8 - amplitude: 1 - pos_x: 176 + pos_x: -80 pos_y: 0 sigma_x: 8 sigma_y: 8 @@ -106,12 +106,7 @@ spa_params: sigma_x: 5 sigma_y: 5 - amplitude: 1 - pos_x: 175 + pos_x: -81 pos_y: 108 sigma_x: 5 sigma_y: 5 - - amplitude: 1 - pos_x: 254 - pos_y: 109 - sigma_x: 5 - sigma_y: 8 diff --git a/specsscan/core.py b/specsscan/core.py index f1ca599..2f522e0 100755 --- a/specsscan/core.py +++ b/specsscan/core.py @@ -224,6 +224,9 @@ def load_scan( else: res_xarray = res_xarray.transpose("Angle", "Ekin", dim) + slow_axes = {dim} if dim else set() + fast_axes = set(res_xarray.dims) - slow_axes + projection = "reciprocal" if "Angle" in fast_axes else "real" conversion_metadata = res_xarray.attrs.pop("conversion_parameters") # rename coords and store mapping information, if available @@ -238,6 +241,13 @@ def load_scan( if k in res_xarray.dims } res_xarray = res_xarray.rename(rename_dict) + for k, v in coordinate_mapping.items(): + if k in fast_axes: + fast_axes.remove(k) + fast_axes.add(v) + if k in slow_axes: + slow_axes.remove(k) + slow_axes.add(v) self._scan_info["coordinate_depends"] = depends_dict axis_dict = { @@ -262,8 +272,9 @@ def load_scan( df_lut, self._scan_info, self.config, - fast_axis="Angle" if "Angle" in res_xarray.dims else "Position", - slow_axis=dim, + fast_axes=list(fast_axes), # type: ignore + slow_axes=list(slow_axes), + projection=projection, metadata=copy.deepcopy(metadata), collect_metadata=collect_metadata, ), @@ -312,6 +323,27 @@ def crop_tool(self, scan: int = None, path: Path | str = "", **kwds): **kwds, ) + def fft_tool(self, scan: int = None, path: Path | str = "", **kwds): + matplotlib.use("module://ipympl.backend_nbagg") + if scan is not None: + scan_path = get_scan_path(path, scan, self._config["data_path"]) + + data = load_images( + scan_path=scan_path, + tqdm_enable_nested=self._config["enable_nested_progress_bar"], + ) + image = data[0] + else: + try: + image = self.metadata["loader"]["raw_data"][0] + except KeyError as exc: + raise ValueError("No image loaded, load image first!") from exc + + self.spa.fft_tool( + image, + **kwds, + ) + def check_scan( self, scan: int, @@ -411,13 +443,18 @@ def check_scan( except KeyError: pass + slow_axes = {"Iteration"} + fast_axes = set(res_xarray.dims) - slow_axes + projection = "reciprocal" if "Angle" in fast_axes else "real" + self.metadata.update( **handle_meta( df_lut, self._scan_info, self.config, - fast_axis="Angle" if "Angle" in res_xarray.dims else "Position", - slow_axis=dims[1], + fast_axes=list(fast_axes), # type: ignore + slow_axes=list(slow_axes), + projection=projection, metadata=metadata, collect_metadata=collect_metadata, ), diff --git a/specsscan/helpers.py b/specsscan/helpers.py index 5b0e4c2..174da88 100644 --- a/specsscan/helpers.py +++ b/specsscan/helpers.py @@ -348,8 +348,9 @@ def handle_meta( df_lut: pd.DataFrame, scan_info: dict, config: dict, - fast_axis: str, - slow_axis: str, + fast_axes: list[str], + slow_axes: list[str], + projection: str, metadata: dict = None, collect_metadata: bool = False, ) -> dict: @@ -360,8 +361,8 @@ def handle_meta( from ``parse_lut_to_df()`` scan_info (dict): scan_info class dict containing containing the contents of info.txt file config (dict): config dictionary containing the contents of config.yaml file - fast_axis (str): The fast-axis dimension of the scan - slow_axis (str): The slow-axis dimension of the scan + fast_axes (list[str]): The fast-axis dimensions of the scan + slow_axes (list[str]): The slow-axis dimensions of the scan metadata (dict, optional): Metadata dictionary with additional metadata for the scan. Defaults to empty dictionary. collect_metadata (bool, optional): Option to collect further metadata e.g. from EPICS @@ -470,14 +471,13 @@ def handle_meta( metadata["scan_info"]["energy_scan_mode"] = energy_scan_mode - projection = "reciprocal" if fast_axis in {"Anlge", "angular0", "angular1"} else "real" metadata["scan_info"]["projection"] = projection metadata["scan_info"]["scheme"] = ( "angular dispersive" if projection == "reciprocal" else "spatial dispersive" ) - metadata["scan_info"]["slow_axes"] = slow_axis - metadata["scan_info"]["fast_axes"] = ["Ekin", fast_axis] + metadata["scan_info"]["slow_axes"] = slow_axes + metadata["scan_info"]["fast_axes"] = fast_axes print("Done!") diff --git a/tests/data b/tests/data index 9a11106..a6ef660 160000 --- a/tests/data +++ b/tests/data @@ -1 +1 @@ -Subproject commit 9a11106b902dbf1089de50813eae649ccf6e3a83 +Subproject commit a6ef660ee7032f1d474d30f12299e3bbd417a5c4 diff --git a/tests/test_specsscan.py b/tests/test_specsscan.py index 3b77349..f27d1ae 100755 --- a/tests/test_specsscan.py +++ b/tests/test_specsscan.py @@ -6,11 +6,13 @@ import pytest import specsscan +from specsanalyzer.core import create_fft_params from specsscan import __version__ from specsscan import SpecsScan package_dir = os.path.dirname(specsscan.__file__) test_dir = package_dir + "/../tests/data/" +fft_filter_peaks = create_fft_params(amp=1, pos_x=82, pos_y=116, sig_x=15, sig_y=23) def test_version(): @@ -65,14 +67,19 @@ def test_conversion_3d(): path=test_dir, iterations=[0], ) - assert res_xarray.sum(axis=(0, 1, 2)) != res_xarray2.sum(axis=(0, 1, 2)) + np.testing.assert_raises( + AssertionError, + np.testing.assert_allclose, + res_xarray.values, + res_xarray2.values, + ) res_xarray2 = sps.load_scan( scan=4450, path=test_dir, iterations=np.s_[0:2], ) - assert res_xarray.sum(axis=(0, 1, 2)) == res_xarray2.sum(axis=(0, 1, 2)) + np.testing.assert_allclose(res_xarray, res_xarray2) with pytest.raises(IndexError): sps.load_scan( @@ -257,6 +264,40 @@ def test_crop_tool(): assert res_xarray.Ekin[-1] == 22.826511627906974 +def test_fft_tool(): + """Test the fft tool""" + + sps = SpecsScan( + config=test_dir + "config.yaml", + user_config={}, + system_config={}, + ) + res_xarray = sps.load_scan( + scan=3610, + path=test_dir, + apply_fft_filter=False, + ) + + np.testing.assert_almost_equal(res_xarray.data.sum(), 62145561928.15108, decimal=3) + + res_xarray = sps.load_scan( + scan=3610, + path=test_dir, + fft_filter_peaks=fft_filter_peaks, + apply_fft_filter=True, + ) + np.testing.assert_almost_equal(res_xarray.data.sum(), 62197237155.50347, decimal=3) + + sps.fft_tool( + fft_tool_params={"amplitude": 1, "pos_x": 82, "pos_y": 116, "sigma_x": 15, "sigma_y": 23}, + apply=True, + ) + assert sps.config["spa_params"]["fft_filter_peaks"] == fft_filter_peaks + assert sps.spa.config["fft_filter_peaks"] == fft_filter_peaks + res_xarray = sps.load_scan(scan=3610, path=test_dir, apply_fft_filter=True) + np.testing.assert_almost_equal(res_xarray.data.sum(), 62197237155.50347, decimal=3) + + def test_conversion_and_save_to_nexus(): """Test the conversion of a tilt scan and saving as NeXus""" config = {"nexus": {"input_files": [package_dir + "/config/NXmpes_arpes_config.json"]}} diff --git a/tutorial/1_specsanalyzer_conversion_examples.ipynb b/tutorial/1_specsanalyzer_conversion_examples.ipynb index 5472e6d..a2d9ec7 100644 --- a/tutorial/1_specsanalyzer_conversion_examples.ipynb +++ b/tutorial/1_specsanalyzer_conversion_examples.ipynb @@ -169,6 +169,57 @@ "res_xarray.plot()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Alternatively, one can use the interactive fft tool to optimize the fft peak positions of the grid." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "spa.fft_tool(tsv_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The peak parameters are stored in the config dict which can be passed as kwds to the convert_image function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fft_filter_peaks = spa.config['fft_filter_peaks']\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "res_xarray = spa.convert_image(\n", + " tsv_data,\n", + " lens_mode,\n", + " kinetic_energy,\n", + " pass_energy,\n", + " work_function,\n", + " apply_fft_filter=True,\n", + " fft_filter_peaks=fft_filter_peaks\n", + ")\n", + "plt.figure()\n", + "res_xarray.plot()" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/tutorial/2_specsscan_example.ipynb b/tutorial/2_specsscan_example.ipynb index cab131c..0f1ac21 100644 --- a/tutorial/2_specsscan_example.ipynb +++ b/tutorial/2_specsscan_example.ipynb @@ -184,6 +184,61 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Removal of Mesh Artifact" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to remove the meshgrid artifact present in the data, an fft filtering is applied already in the data loaded previously. For this, parameters of the fft peaks corresponding to the grid are required which can be provided in the config file. Alternatively, one can also interactively optimize the parameters using the fft tool" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sps.fft_tool(\n", + " fft_tool_params={\n", + " \"amplitude\": 1,\n", + " \"pos_x\": 82,\n", + " \"pos_y\": 116,\n", + " \"sigma_x\": 15,\n", + " \"sigma_y\": 23\n", + " },\n", + " apply=True # Use apply=False for interactive mode\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load the scan again for the new changes to apply to all the images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "res_xarray = sps.load_scan(\n", + " scan=4450,\n", + " path=path,\n", + " apply_fft_filter=True\n", + ")\n", + "\n", + "plt.figure()\n", + "res_xarray[:,:,0].plot()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -200,7 +255,7 @@ "outputs": [], "source": [ "plt.figure()\n", - "sps.result.loc[{\"Angle\": slice(-5, 5)}].sum(axis=0).plot()" + "sps.result.loc[{\"angular1\": slice(-5, 5)}].sum(axis=0).plot()" ] }, { @@ -317,7 +372,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.0" }, "vscode": { "interpreter": {