From bdb5a48a9fc9bcbb62bff8ecabe51aa2e81aa3e2 Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Mon, 2 Jan 2023 00:53:31 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20coherent=20artifact=20plot=20?= =?UTF-8?q?functionality=20(#123)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ Added abs_max helper function * ♻️ Refactored add_cycler_if_not_none to also work on multiple axes * ✨ Added plot_coherent_artifact function * 🧹 Apply suggestions from code review * 👌 Changed default of normalize in plot_coherent_artifact to True (as requested by @ism200) Co-authored-by: Joris Snellenburg --- pyglotaran_extras/__init__.py | 2 + .../plotting/plot_coherent_artifact.py | 123 ++++++++++++++++++ pyglotaran_extras/plotting/utils.py | 39 +++++- tests/plotting/test_utils.py | 46 ++++++- 4 files changed, 202 insertions(+), 8 deletions(-) create mode 100644 pyglotaran_extras/plotting/plot_coherent_artifact.py diff --git a/pyglotaran_extras/__init__.py b/pyglotaran_extras/__init__.py index bd35d74c..009a7f51 100644 --- a/pyglotaran_extras/__init__.py +++ b/pyglotaran_extras/__init__.py @@ -1,6 +1,7 @@ """Pyglotaran extension package with convenience functionality such as plotting.""" from pyglotaran_extras.io.load_data import load_data from pyglotaran_extras.io.setup_case_study import setup_case_study +from pyglotaran_extras.plotting.plot_coherent_artifact import plot_coherent_artifact from pyglotaran_extras.plotting.plot_data import plot_data_overview from pyglotaran_extras.plotting.plot_guidance import plot_guidance from pyglotaran_extras.plotting.plot_irf_dispersion_center import plot_irf_dispersion_center @@ -12,6 +13,7 @@ __all__ = [ "load_data", "setup_case_study", + "plot_coherent_artifact", "plot_data_overview", "plot_overview", "plot_simple_overview", diff --git a/pyglotaran_extras/plotting/plot_coherent_artifact.py b/pyglotaran_extras/plotting/plot_coherent_artifact.py new file mode 100644 index 00000000..74ca6194 --- /dev/null +++ b/pyglotaran_extras/plotting/plot_coherent_artifact.py @@ -0,0 +1,123 @@ +"""Module containing coherent artifact plot functionality.""" +from __future__ import annotations + +from typing import TYPE_CHECKING +from warnings import warn + +import matplotlib.pyplot as plt +import numpy as np +import xarray as xr +from cycler import Cycler + +from pyglotaran_extras.plotting.utils import abs_max +from pyglotaran_extras.plotting.utils import add_cycler_if_not_none + +if TYPE_CHECKING: + from matplotlib.figure import Figure + from matplotlib.pyplot import Axes + + +def plot_coherent_artifact( + res: xr.Dataset, + *, + time_range: tuple[float, float] | None = None, + spectral: float = 0, + normalize: bool = True, + figsize: tuple[int, int] = (18, 7), + show_zero_line: bool = True, + cycler: Cycler | None = None, + title: str | None = "Coherent Artifact", +) -> tuple[Figure, Axes]: + """Plot coherent artifact as IRF derivative components over time and IRFAS over spectral dim. + + The IRFAS are the IRF (Instrument Response Function) Associated Spectra. + + Parameters + ---------- + res: xr.Dataset + Result dataset from a pyglotaran optimization. + time_range: tuple[float, float] | None + Start and end time for the IRF derivative plot. Defaults to None which means that + the full time range is used. + spectral: float + Value of the spectral axis that should be used to select the data for the IRF derivative + plot this value does not need to be an exact existing value and only has effect if the + IRF has dispersion. Defaults to 0 which means that the IRF derivative plot at lowest + spectral value will be shown. + normalize: bool + Whether or not to normalize the IRF derivative plot. If the IRF derivative is normalized, + the IRFAS is scaled with the reciprocal of the normalization to compensate for this. + Defaults to True. + figsize: tuple[int, int] + Size of the figure (N, M) in inches. Defaults to (18, 7). + show_zero_line: bool + Whether or not to add a horizontal line at zero. Defaults to True. + cycler: Cycler | None + Plot style cycler to use. Defaults to None, which means that the matplotlib default style + will be used. + title: str | None + Title of the figure. Defaults to "Coherent Artifact". + + Returns + ------- + tuple[Figure, Axes] + Figure object which contains the plots and the Axes. + """ + fig, axes = plt.subplots(1, 2, figsize=figsize) + add_cycler_if_not_none(axes, cycler) + + if ( + "coherent_artifact_response" not in res + or "coherent_artifact_associated_spectra" not in res + ): + warn( + UserWarning(f"Dataset does not contain coherent artifact data:\n {res.data_vars}"), + stacklevel=2, + ) + return fig, axes + + irf_max = abs_max(res.coherent_artifact_response, result_dims=("coherent_artifact_order")) + irfas_max = abs_max( + res.coherent_artifact_associated_spectra, result_dims=("coherent_artifact_order") + ) + scales = np.sqrt(irfas_max * irf_max) + norm_factor = 1 + irf_y_label = "amplitude" + irfas_y_label = "ΔA" + + if normalize is True: + norm_factor = scales.max() + irf_y_label = f"normalized {irf_y_label}" + + plot_slice_irf = ( + res.coherent_artifact_response.sel(spectral=spectral, method="nearest") + / irf_max + * scales + / norm_factor + ) + irf_sel_kwargs = ( + {"time": slice(time_range[0], time_range[1])} if time_range is not None else {} + ) + plot_slice_irf.sel(**irf_sel_kwargs).plot.line(x="time", ax=axes[0]) + axes[0].set_title("IRF Derivatives") + axes[0].set_ylabel(f"{irf_y_label} (a.u.)") + + plot_slice_irfas = res.coherent_artifact_associated_spectra / irfas_max * scales * norm_factor + plot_slice_irfas.plot.line(x="spectral", ax=axes[1]) + axes[1].get_legend().remove() + axes[1].set_title("IRFAS") + axes[1].set_ylabel(f"{irfas_y_label} (mOD)") + + if show_zero_line is True: + axes[0].axhline(0, color="k", linewidth=1) + axes[1].axhline(0, color="k", linewidth=1) + + # + if res.coords["coherent_artifact_order"][0] == 1: + axes[0].legend( + [f"{int(ax_label)-1}" for ax_label in res.coords["coherent_artifact_order"]], + title="coherent_artifact_order", + ) + if title: + fig.suptitle(title, fontsize=16) + return fig, axes diff --git a/pyglotaran_extras/plotting/utils.py b/pyglotaran_extras/plotting/utils.py index 2f92b707..8139db46 100644 --- a/pyglotaran_extras/plotting/utils.py +++ b/pyglotaran_extras/plotting/utils.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import Iterable from warnings import warn import numpy as np @@ -10,7 +11,7 @@ from pyglotaran_extras.io.utils import result_dataset_mapping if TYPE_CHECKING: - from typing import Iterable + from typing import Hashable from cycler import Cycler from matplotlib.axis import Axis @@ -360,7 +361,7 @@ def get_shifted_traces( return shift_time_axis_by_irf_location(traces, irf_location) -def add_cycler_if_not_none(axis: Axis, cycler: Cycler | None) -> None: +def add_cycler_if_not_none(axis: Axis | Axes, cycler: Cycler | None) -> None: """Add cycler to and axis if it is not None. This is a convenience function that allow to opt out of using @@ -370,10 +371,38 @@ def add_cycler_if_not_none(axis: Axis, cycler: Cycler | None) -> None: Parameters ---------- - axis: Axis - Axis to plot the data and fits on. + axis: Axis | Axes + Axis to plot on. cycler: Cycler | None Plot style cycler to use. """ if cycler is not None: - axis.set_prop_cycle(cycler) + # We can't use `Axis` in isinstance so we check for the np.ndarray attribute of `Axes` + if hasattr(axis, "flatten") is False: + axis = np.array([axis]) + for ax in axis.flatten(): + ax.set_prop_cycle(cycler) + + +def abs_max( + data: xr.DataArray, *, result_dims: Hashable | Iterable[Hashable] = () +) -> xr.DataArray: + """Calculate the absolute maximum values of ``data`` along all dims except ``result_dims``. + + Parameters + ---------- + data: xr.DataArray + Data for which the absolute maximum should be calculated. + result_dims: Hashable | Iterable[Hashable] + Dimensions of ``data`` which should be preserved and part of the resulting DataArray. + Defaults to () which results in using the absolute maximum of all values. + + Returns + ------- + xr.DataArray + Absolute maximum values of ``data`` with dimensions ``result_dims``. + """ + if not isinstance(result_dims, Iterable): + result_dims = (result_dims,) + reduce_dims = (dim for dim in data.dims if dim not in result_dims) + return np.abs(data).max(dim=reduce_dims) diff --git a/tests/plotting/test_utils.py b/tests/plotting/test_utils.py index 9119cc37..c5681207 100644 --- a/tests/plotting/test_utils.py +++ b/tests/plotting/test_utils.py @@ -1,13 +1,18 @@ """Tests for pyglotaran_extras.plotting.utils""" from __future__ import annotations +from typing import Hashable +from typing import Iterable + import matplotlib import matplotlib.pyplot as plt import pytest +import xarray as xr from cycler import Cycler from cycler import cycle from pyglotaran_extras.plotting.style import PlotStyle +from pyglotaran_extras.plotting.utils import abs_max from pyglotaran_extras.plotting.utils import add_cycler_if_not_none matplotlib.use("Agg") @@ -18,10 +23,45 @@ "cycler,expected_cycler", ((None, DEFAULT_CYCLER()), (PlotStyle().cycler, PlotStyle().cycler())), ) -def test_add_cycler_if_not_none(cycler: Cycler | None, expected_cycler: cycle): - """Default cycler inf None and cycler otherwise""" +def test_add_cycler_if_not_none_single_axis(cycler: Cycler | None, expected_cycler: cycle): + """Default cycler if None and cycler otherwise on a single axis""" ax = plt.subplot() add_cycler_if_not_none(ax, cycler) for _ in range(10): - assert next(ax._get_lines.prop_cycler) == next(expected_cycler) + expected = next(expected_cycler) + assert next(ax._get_lines.prop_cycler) == expected + + +@pytest.mark.parametrize( + "cycler,expected_cycler", + ((None, DEFAULT_CYCLER()), (PlotStyle().cycler, PlotStyle().cycler())), +) +def test_add_cycler_if_not_none_multiple_axes(cycler: Cycler | None, expected_cycler: cycle): + """Default cycler if None and cycler otherwise on all axes""" + _, axes = plt.subplots(1, 2) + add_cycler_if_not_none(axes, cycler) + + for _ in range(10): + expected = next(expected_cycler) + assert next(axes[0]._get_lines.prop_cycler) == expected + assert next(axes[1]._get_lines.prop_cycler) == expected + + +@pytest.mark.parametrize( + "result_dims, expected", + ( + ((), xr.DataArray(40)), + ("dim1", xr.DataArray([20, 40], coords={"dim1": [1, 2]})), + ("dim2", xr.DataArray([30, 40], coords={"dim2": [3, 4]})), + (("dim1",), xr.DataArray([20, 40], coords={"dim1": [1, 2]})), + ( + ("dim1", "dim2"), + xr.DataArray([[10, 20], [30, 40]], coords={"dim1": [1, 2], "dim2": [3, 4]}), + ), + ), +) +def test_abs_max(result_dims: Hashable | Iterable[Hashable], expected: xr.DataArray): + """Result values are positive and dimensions are preserved if result_dims is not empty.""" + data = xr.DataArray([[-10, 20], [-30, 40]], coords={"dim1": [1, 2], "dim2": [3, 4]}) + assert abs_max(data, result_dims=result_dims).equals(expected)