From f4ea1577e6ddd96c9e38d75107e26869344a01b2 Mon Sep 17 00:00:00 2001 From: s-weigand Date: Fri, 25 Aug 2023 00:17:37 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=A9=B9=20Fix=20crash=20when=20plotting=20?= =?UTF-8?q?spectral=20model=20result?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../plotting/plot_coherent_artifact.py | 4 ++- pyglotaran_extras/plotting/plot_data.py | 2 +- pyglotaran_extras/plotting/plot_doas.py | 2 +- pyglotaran_extras/plotting/plot_residual.py | 2 +- pyglotaran_extras/plotting/plot_svd.py | 28 ++++++++++++++----- pyglotaran_extras/plotting/utils.py | 8 ++++-- 6 files changed, 32 insertions(+), 14 deletions(-) diff --git a/pyglotaran_extras/plotting/plot_coherent_artifact.py b/pyglotaran_extras/plotting/plot_coherent_artifact.py index 6d8cb24e..2a935299 100644 --- a/pyglotaran_extras/plotting/plot_coherent_artifact.py +++ b/pyglotaran_extras/plotting/plot_coherent_artifact.py @@ -88,7 +88,9 @@ def plot_coherent_artifact( return fig, axes irf_location = extract_irf_location(dataset, spectral, main_irf_nr) - irf_data = shift_time_axis_by_irf_location(dataset.coherent_artifact_response, irf_location) + irf_data = shift_time_axis_by_irf_location( + dataset.coherent_artifact_response, irf_location, _internal_call=True + ) irf_max = abs_max(irf_data, result_dims=("coherent_artifact_order")) irfas_max = abs_max( diff --git a/pyglotaran_extras/plotting/plot_data.py b/pyglotaran_extras/plotting/plot_data.py index da3505a5..f4ce96a4 100644 --- a/pyglotaran_extras/plotting/plot_data.py +++ b/pyglotaran_extras/plotting/plot_data.py @@ -68,7 +68,7 @@ def plot_data_overview( Figure and axes which can then be refined by the user. """ dataset = load_data(dataset, _stacklevel=3) - data = shift_time_axis_by_irf_location(dataset.data, irf_location) + data = shift_time_axis_by_irf_location(dataset.data, irf_location, _internal_call=True) if len(not_single_element_dims(data)) == 1: return _plot_single_trace( diff --git a/pyglotaran_extras/plotting/plot_doas.py b/pyglotaran_extras/plotting/plot_doas.py index d17bc6e1..dc1b673b 100644 --- a/pyglotaran_extras/plotting/plot_doas.py +++ b/pyglotaran_extras/plotting/plot_doas.py @@ -111,7 +111,7 @@ def plot_doas( oscillations = oscillations.sel(spectral=spectral, method="nearest") oscillations = shift_time_axis_by_irf_location( - oscillations.sel(**osc_sel_kwargs), irf_location + oscillations.sel(**osc_sel_kwargs), irf_location, _internal_call=True ) oscillations_spectra = dataset["damped_oscillation_associated_spectra"].sel(**osc_sel_kwargs) diff --git a/pyglotaran_extras/plotting/plot_residual.py b/pyglotaran_extras/plotting/plot_residual.py index 9fe0fe72..dd6d85a8 100644 --- a/pyglotaran_extras/plotting/plot_residual.py +++ b/pyglotaran_extras/plotting/plot_residual.py @@ -65,7 +65,7 @@ def plot_residual( add_cycler_if_not_none(ax, cycler) data = res.data if show_data else res.residual - data = shift_time_axis_by_irf_location(data, irf_location) + data = shift_time_axis_by_irf_location(data, irf_location, _internal_call=True) title = "dataset" if show_data else "residual" shape = np.array(data.shape) # Handle different dimensionality of data diff --git a/pyglotaran_extras/plotting/plot_svd.py b/pyglotaran_extras/plotting/plot_svd.py index 4805f86b..ffb167bc 100644 --- a/pyglotaran_extras/plotting/plot_svd.py +++ b/pyglotaran_extras/plotting/plot_svd.py @@ -78,6 +78,7 @@ def plot_svd( cycler=cycler, indices=range(nr_of_residual_svd_vectors), show_legend=show_residual_svd_legend, + irf_location=irf_location, ) plot_sv_residual(res, axes[0, 2], cycler=cycler) add_svd_to_dataset(dataset=res, name="data") @@ -97,6 +98,7 @@ def plot_svd( cycler=cycler, indices=range(nr_of_data_svd_vectors), show_legend=show_data_svd_legend, + irf_location=irf_location, ) plot_sv_data(res, axes[1, 2], cycler=cycler) @@ -136,8 +138,7 @@ def plot_lsv_data( """ add_cycler_if_not_none(ax, cycler) dLSV = res.data_left_singular_vectors # noqa: N806 - dLSV = shift_time_axis_by_irf_location(dLSV, irf_location) # noqa: N806 - _plot_svd_vectors(dLSV, indices, "left_singular_value_index", ax, show_legend) + _plot_svd_vectors(dLSV, indices, "left_singular_value_index", ax, show_legend, irf_location) ax.set_title("data. LSV") if linlog: ax.set_xscale("symlog", linthresh=linthresh) @@ -150,6 +151,7 @@ def plot_rsv_data( indices: Sequence[int] = range(4), cycler: Cycler | None = PlotStyle().cycler, show_legend: bool = True, + irf_location: float | None = None, ) -> None: """Plot right singular vectors (spectra) of the data matrix. @@ -165,10 +167,13 @@ def plot_rsv_data( Plot style cycler to use. Defaults to PlotStyle().cycler. show_legend : bool Whether or not to show the legend. Defaults to True. + irf_location : float | None + Location of the ``irf`` by which the time axis will get shifted. If it is None the time + axis will not be shifted. Defaults to None. """ add_cycler_if_not_none(ax, cycler) dRSV = res.data_right_singular_vectors # noqa: N806 - _plot_svd_vectors(dRSV, indices, "right_singular_value_index", ax, show_legend) + _plot_svd_vectors(dRSV, indices, "right_singular_value_index", ax, show_legend, irf_location) ax.set_title("data. RSV") @@ -237,8 +242,7 @@ def plot_lsv_residual( rLSV = res.weighted_residual_left_singular_vectors # noqa: N806 else: rLSV = res.residual_left_singular_vectors # noqa: N806 - rLSV = shift_time_axis_by_irf_location(rLSV, irf_location) # noqa: N806 - _plot_svd_vectors(rLSV, indices, "left_singular_value_index", ax, show_legend) + _plot_svd_vectors(rLSV, indices, "left_singular_value_index", ax, show_legend, irf_location) ax.set_title("res. LSV") if linlog: ax.set_xscale("symlog", linthresh=linthresh) @@ -251,6 +255,7 @@ def plot_rsv_residual( indices: Sequence[int] = range(2), cycler: Cycler | None = PlotStyle().cycler, show_legend: bool = True, + irf_location: float | None = None, ) -> None: """Plot right singular vectors (spectra) of the residual matrix. @@ -266,13 +271,16 @@ def plot_rsv_residual( Plot style cycler to use. Defaults to PlotStyle().cycler. show_legend : bool Whether or not to show the legend. Defaults to True. + irf_location : float | None + Location of the ``irf`` by which the time axis will get shifted. If it is None the time + axis will not be shifted. Defaults to None. """ add_cycler_if_not_none(ax, cycler) if "weighted_residual_right_singular_vectors" in res: rRSV = res.weighted_residual_right_singular_vectors # noqa: N806 else: rRSV = res.residual_right_singular_vectors # noqa: N806 - _plot_svd_vectors(rRSV, indices, "right_singular_value_index", ax, show_legend) + _plot_svd_vectors(rRSV, indices, "right_singular_value_index", ax, show_legend, irf_location) ax.set_title("res. RSV") @@ -312,6 +320,7 @@ def _plot_svd_vectors( sv_index_dim: str, ax: Axis, show_legend: bool, + irf_location: float | None, ) -> None: """Plot SVD vectors with decreasing zorder on axis ``ax``. @@ -327,6 +336,9 @@ def _plot_svd_vectors( Axis to plot on. show_legend : bool Whether or not to show the legend. + irf_location : float | None + Location of the ``irf`` by which the time axis will get shifted. If it is None the time + axis will not be shifted. Defaults to None. See Also -------- @@ -336,7 +348,9 @@ def _plot_svd_vectors( plot_rsv_residual """ max_index = len(getattr(vector_data, sv_index_dim)) - values = vector_data.isel(**{sv_index_dim: indices[:max_index]}) + values = shift_time_axis_by_irf_location( + vector_data.isel(**{sv_index_dim: indices[:max_index]}), irf_location, _internal_call=True + ) x_dim = vector_data.dims[1] if x_dim == sv_index_dim: values = values.T diff --git a/pyglotaran_extras/plotting/utils.py b/pyglotaran_extras/plotting/utils.py index dd501309..5424e8a5 100644 --- a/pyglotaran_extras/plotting/utils.py +++ b/pyglotaran_extras/plotting/utils.py @@ -307,8 +307,7 @@ def extract_dataset_scale(res: xr.Dataset, divide_by_scale: bool = True) -> floa def shift_time_axis_by_irf_location( - plot_data: xr.DataArray, - irf_location: float | None, + plot_data: xr.DataArray, irf_location: float | None, *, _internal_call: bool = False ) -> xr.DataArray: """Shift ``plot_data`` 'time' axis by the position of the main ``irf``. @@ -318,6 +317,9 @@ def shift_time_axis_by_irf_location( Data to plot. irf_location : float | None Location of the ``irf``, if the value is None the original ``plot_data`` will be returned. + _internal_call : bool + This indicates internal use stripping away user help and silently skipping execution. + Defaults to False. Returns ------- @@ -333,7 +335,7 @@ def shift_time_axis_by_irf_location( -------- extract_irf_location """ - if irf_location is None: + if irf_location is None or ("time" not in plot_data.coords and _internal_call is True): return plot_data if "time" not in plot_data.coords: