Skip to content

Commit

Permalink
🩹 Fix crash when plotting spectral model result (#200)
Browse files Browse the repository at this point in the history
This change fixes a high-level plot functions crash 
if the result data do not have a `time` axis (e.g. for a spectral model) and should be fully backwards compatible.
This is just a hotfix since it does not implement fully correct behavior for results from a spectral models (e.g. which plot to apply `linlog` to) and will be addressed in the future.

### Change summary

- [🩹 Fix crash when plotting spectral model
result](f4ea157)

### Checklist

- [ ] ✔️ Passing the tests (mandatory for all PR's)
  • Loading branch information
s-weigand authored Aug 24, 2023
1 parent 60b74fe commit da8418b
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 14 deletions.
4 changes: 3 additions & 1 deletion pyglotaran_extras/plotting/plot_coherent_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def plot_coherent_artifact(
return fig, axes

irf_location = extract_irf_location(dataset, spectral, main_irf_nr)
irf_data = shift_time_axis_by_irf_location(dataset.coherent_artifact_response, irf_location)
irf_data = shift_time_axis_by_irf_location(
dataset.coherent_artifact_response, irf_location, _internal_call=True
)

irf_max = abs_max(irf_data, result_dims=("coherent_artifact_order"))
irfas_max = abs_max(
Expand Down
2 changes: 1 addition & 1 deletion pyglotaran_extras/plotting/plot_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def plot_data_overview(
Figure and axes which can then be refined by the user.
"""
dataset = load_data(dataset, _stacklevel=3)
data = shift_time_axis_by_irf_location(dataset.data, irf_location)
data = shift_time_axis_by_irf_location(dataset.data, irf_location, _internal_call=True)

if len(not_single_element_dims(data)) == 1:
return _plot_single_trace(
Expand Down
2 changes: 1 addition & 1 deletion pyglotaran_extras/plotting/plot_doas.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def plot_doas(
oscillations = oscillations.sel(spectral=spectral, method="nearest")

oscillations = shift_time_axis_by_irf_location(
oscillations.sel(**osc_sel_kwargs), irf_location
oscillations.sel(**osc_sel_kwargs), irf_location, _internal_call=True
)
oscillations_spectra = dataset["damped_oscillation_associated_spectra"].sel(**osc_sel_kwargs)

Expand Down
2 changes: 1 addition & 1 deletion pyglotaran_extras/plotting/plot_residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def plot_residual(

add_cycler_if_not_none(ax, cycler)
data = res.data if show_data else res.residual
data = shift_time_axis_by_irf_location(data, irf_location)
data = shift_time_axis_by_irf_location(data, irf_location, _internal_call=True)
title = "dataset" if show_data else "residual"
shape = np.array(data.shape)
# Handle different dimensionality of data
Expand Down
28 changes: 21 additions & 7 deletions pyglotaran_extras/plotting/plot_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def plot_svd(
cycler=cycler,
indices=range(nr_of_residual_svd_vectors),
show_legend=show_residual_svd_legend,
irf_location=irf_location,
)
plot_sv_residual(res, axes[0, 2], cycler=cycler)
add_svd_to_dataset(dataset=res, name="data")
Expand All @@ -97,6 +98,7 @@ def plot_svd(
cycler=cycler,
indices=range(nr_of_data_svd_vectors),
show_legend=show_data_svd_legend,
irf_location=irf_location,
)
plot_sv_data(res, axes[1, 2], cycler=cycler)

Expand Down Expand Up @@ -136,8 +138,7 @@ def plot_lsv_data(
"""
add_cycler_if_not_none(ax, cycler)
dLSV = res.data_left_singular_vectors # noqa: N806
dLSV = shift_time_axis_by_irf_location(dLSV, irf_location) # noqa: N806
_plot_svd_vectors(dLSV, indices, "left_singular_value_index", ax, show_legend)
_plot_svd_vectors(dLSV, indices, "left_singular_value_index", ax, show_legend, irf_location)
ax.set_title("data. LSV")
if linlog:
ax.set_xscale("symlog", linthresh=linthresh)
Expand All @@ -150,6 +151,7 @@ def plot_rsv_data(
indices: Sequence[int] = range(4),
cycler: Cycler | None = PlotStyle().cycler,
show_legend: bool = True,
irf_location: float | None = None,
) -> None:
"""Plot right singular vectors (spectra) of the data matrix.
Expand All @@ -165,10 +167,13 @@ def plot_rsv_data(
Plot style cycler to use. Defaults to PlotStyle().cycler.
show_legend : bool
Whether or not to show the legend. Defaults to True.
irf_location : float | None
Location of the ``irf`` by which the time axis will get shifted. If it is None the time
axis will not be shifted. Defaults to None.
"""
add_cycler_if_not_none(ax, cycler)
dRSV = res.data_right_singular_vectors # noqa: N806
_plot_svd_vectors(dRSV, indices, "right_singular_value_index", ax, show_legend)
_plot_svd_vectors(dRSV, indices, "right_singular_value_index", ax, show_legend, irf_location)
ax.set_title("data. RSV")


Expand Down Expand Up @@ -237,8 +242,7 @@ def plot_lsv_residual(
rLSV = res.weighted_residual_left_singular_vectors # noqa: N806
else:
rLSV = res.residual_left_singular_vectors # noqa: N806
rLSV = shift_time_axis_by_irf_location(rLSV, irf_location) # noqa: N806
_plot_svd_vectors(rLSV, indices, "left_singular_value_index", ax, show_legend)
_plot_svd_vectors(rLSV, indices, "left_singular_value_index", ax, show_legend, irf_location)
ax.set_title("res. LSV")
if linlog:
ax.set_xscale("symlog", linthresh=linthresh)
Expand All @@ -251,6 +255,7 @@ def plot_rsv_residual(
indices: Sequence[int] = range(2),
cycler: Cycler | None = PlotStyle().cycler,
show_legend: bool = True,
irf_location: float | None = None,
) -> None:
"""Plot right singular vectors (spectra) of the residual matrix.
Expand All @@ -266,13 +271,16 @@ def plot_rsv_residual(
Plot style cycler to use. Defaults to PlotStyle().cycler.
show_legend : bool
Whether or not to show the legend. Defaults to True.
irf_location : float | None
Location of the ``irf`` by which the time axis will get shifted. If it is None the time
axis will not be shifted. Defaults to None.
"""
add_cycler_if_not_none(ax, cycler)
if "weighted_residual_right_singular_vectors" in res:
rRSV = res.weighted_residual_right_singular_vectors # noqa: N806
else:
rRSV = res.residual_right_singular_vectors # noqa: N806
_plot_svd_vectors(rRSV, indices, "right_singular_value_index", ax, show_legend)
_plot_svd_vectors(rRSV, indices, "right_singular_value_index", ax, show_legend, irf_location)
ax.set_title("res. RSV")


Expand Down Expand Up @@ -312,6 +320,7 @@ def _plot_svd_vectors(
sv_index_dim: str,
ax: Axis,
show_legend: bool,
irf_location: float | None,
) -> None:
"""Plot SVD vectors with decreasing zorder on axis ``ax``.
Expand All @@ -327,6 +336,9 @@ def _plot_svd_vectors(
Axis to plot on.
show_legend : bool
Whether or not to show the legend.
irf_location : float | None
Location of the ``irf`` by which the time axis will get shifted. If it is None the time
axis will not be shifted. Defaults to None.
See Also
--------
Expand All @@ -336,7 +348,9 @@ def _plot_svd_vectors(
plot_rsv_residual
"""
max_index = len(getattr(vector_data, sv_index_dim))
values = vector_data.isel(**{sv_index_dim: indices[:max_index]})
values = shift_time_axis_by_irf_location(
vector_data.isel(**{sv_index_dim: indices[:max_index]}), irf_location, _internal_call=True
)
x_dim = vector_data.dims[1]
if x_dim == sv_index_dim:
values = values.T
Expand Down
8 changes: 5 additions & 3 deletions pyglotaran_extras/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,7 @@ def extract_dataset_scale(res: xr.Dataset, divide_by_scale: bool = True) -> floa


def shift_time_axis_by_irf_location(
plot_data: xr.DataArray,
irf_location: float | None,
plot_data: xr.DataArray, irf_location: float | None, *, _internal_call: bool = False
) -> xr.DataArray:
"""Shift ``plot_data`` 'time' axis by the position of the main ``irf``.
Expand All @@ -318,6 +317,9 @@ def shift_time_axis_by_irf_location(
Data to plot.
irf_location : float | None
Location of the ``irf``, if the value is None the original ``plot_data`` will be returned.
_internal_call : bool
This indicates internal use stripping away user help and silently skipping execution.
Defaults to False.
Returns
-------
Expand All @@ -333,7 +335,7 @@ def shift_time_axis_by_irf_location(
--------
extract_irf_location
"""
if irf_location is None:
if irf_location is None or ("time" not in plot_data.coords and _internal_call is True):
return plot_data

if "time" not in plot_data.coords:
Expand Down

0 comments on commit da8418b

Please sign in to comment.