Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🩹 Fix crash when plotting spectral model result #200

Merged
merged 1 commit into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyglotaran_extras/plotting/plot_coherent_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def plot_coherent_artifact(
return fig, axes

irf_location = extract_irf_location(dataset, spectral, main_irf_nr)
irf_data = shift_time_axis_by_irf_location(dataset.coherent_artifact_response, irf_location)
irf_data = shift_time_axis_by_irf_location(
dataset.coherent_artifact_response, irf_location, _internal_call=True
)

irf_max = abs_max(irf_data, result_dims=("coherent_artifact_order"))
irfas_max = abs_max(
Expand Down
2 changes: 1 addition & 1 deletion pyglotaran_extras/plotting/plot_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def plot_data_overview(
Figure and axes which can then be refined by the user.
"""
dataset = load_data(dataset, _stacklevel=3)
data = shift_time_axis_by_irf_location(dataset.data, irf_location)
data = shift_time_axis_by_irf_location(dataset.data, irf_location, _internal_call=True)

if len(not_single_element_dims(data)) == 1:
return _plot_single_trace(
Expand Down
2 changes: 1 addition & 1 deletion pyglotaran_extras/plotting/plot_doas.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def plot_doas(
oscillations = oscillations.sel(spectral=spectral, method="nearest")

oscillations = shift_time_axis_by_irf_location(
oscillations.sel(**osc_sel_kwargs), irf_location
oscillations.sel(**osc_sel_kwargs), irf_location, _internal_call=True
)
oscillations_spectra = dataset["damped_oscillation_associated_spectra"].sel(**osc_sel_kwargs)

Expand Down
2 changes: 1 addition & 1 deletion pyglotaran_extras/plotting/plot_residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def plot_residual(

add_cycler_if_not_none(ax, cycler)
data = res.data if show_data else res.residual
data = shift_time_axis_by_irf_location(data, irf_location)
data = shift_time_axis_by_irf_location(data, irf_location, _internal_call=True)
title = "dataset" if show_data else "residual"
shape = np.array(data.shape)
# Handle different dimensionality of data
Expand Down
28 changes: 21 additions & 7 deletions pyglotaran_extras/plotting/plot_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def plot_svd(
cycler=cycler,
indices=range(nr_of_residual_svd_vectors),
show_legend=show_residual_svd_legend,
irf_location=irf_location,
)
plot_sv_residual(res, axes[0, 2], cycler=cycler)
add_svd_to_dataset(dataset=res, name="data")
Expand All @@ -97,6 +98,7 @@ def plot_svd(
cycler=cycler,
indices=range(nr_of_data_svd_vectors),
show_legend=show_data_svd_legend,
irf_location=irf_location,
)
plot_sv_data(res, axes[1, 2], cycler=cycler)

Expand Down Expand Up @@ -136,8 +138,7 @@ def plot_lsv_data(
"""
add_cycler_if_not_none(ax, cycler)
dLSV = res.data_left_singular_vectors # noqa: N806
dLSV = shift_time_axis_by_irf_location(dLSV, irf_location) # noqa: N806
_plot_svd_vectors(dLSV, indices, "left_singular_value_index", ax, show_legend)
_plot_svd_vectors(dLSV, indices, "left_singular_value_index", ax, show_legend, irf_location)
ax.set_title("data. LSV")
if linlog:
ax.set_xscale("symlog", linthresh=linthresh)
Expand All @@ -150,6 +151,7 @@ def plot_rsv_data(
indices: Sequence[int] = range(4),
cycler: Cycler | None = PlotStyle().cycler,
show_legend: bool = True,
irf_location: float | None = None,
) -> None:
"""Plot right singular vectors (spectra) of the data matrix.

Expand All @@ -165,10 +167,13 @@ def plot_rsv_data(
Plot style cycler to use. Defaults to PlotStyle().cycler.
show_legend : bool
Whether or not to show the legend. Defaults to True.
irf_location : float | None
Location of the ``irf`` by which the time axis will get shifted. If it is None the time
axis will not be shifted. Defaults to None.
"""
add_cycler_if_not_none(ax, cycler)
dRSV = res.data_right_singular_vectors # noqa: N806
_plot_svd_vectors(dRSV, indices, "right_singular_value_index", ax, show_legend)
_plot_svd_vectors(dRSV, indices, "right_singular_value_index", ax, show_legend, irf_location)
ax.set_title("data. RSV")


Expand Down Expand Up @@ -237,8 +242,7 @@ def plot_lsv_residual(
rLSV = res.weighted_residual_left_singular_vectors # noqa: N806
else:
rLSV = res.residual_left_singular_vectors # noqa: N806
rLSV = shift_time_axis_by_irf_location(rLSV, irf_location) # noqa: N806
_plot_svd_vectors(rLSV, indices, "left_singular_value_index", ax, show_legend)
_plot_svd_vectors(rLSV, indices, "left_singular_value_index", ax, show_legend, irf_location)
ax.set_title("res. LSV")
if linlog:
ax.set_xscale("symlog", linthresh=linthresh)
Expand All @@ -251,6 +255,7 @@ def plot_rsv_residual(
indices: Sequence[int] = range(2),
cycler: Cycler | None = PlotStyle().cycler,
show_legend: bool = True,
irf_location: float | None = None,
) -> None:
"""Plot right singular vectors (spectra) of the residual matrix.

Expand All @@ -266,13 +271,16 @@ def plot_rsv_residual(
Plot style cycler to use. Defaults to PlotStyle().cycler.
show_legend : bool
Whether or not to show the legend. Defaults to True.
irf_location : float | None
Location of the ``irf`` by which the time axis will get shifted. If it is None the time
axis will not be shifted. Defaults to None.
"""
add_cycler_if_not_none(ax, cycler)
if "weighted_residual_right_singular_vectors" in res:
rRSV = res.weighted_residual_right_singular_vectors # noqa: N806
else:
rRSV = res.residual_right_singular_vectors # noqa: N806
_plot_svd_vectors(rRSV, indices, "right_singular_value_index", ax, show_legend)
_plot_svd_vectors(rRSV, indices, "right_singular_value_index", ax, show_legend, irf_location)
ax.set_title("res. RSV")


Expand Down Expand Up @@ -312,6 +320,7 @@ def _plot_svd_vectors(
sv_index_dim: str,
ax: Axis,
show_legend: bool,
irf_location: float | None,
) -> None:
"""Plot SVD vectors with decreasing zorder on axis ``ax``.

Expand All @@ -327,6 +336,9 @@ def _plot_svd_vectors(
Axis to plot on.
show_legend : bool
Whether or not to show the legend.
irf_location : float | None
Location of the ``irf`` by which the time axis will get shifted. If it is None the time
axis will not be shifted. Defaults to None.

See Also
--------
Expand All @@ -336,7 +348,9 @@ def _plot_svd_vectors(
plot_rsv_residual
"""
max_index = len(getattr(vector_data, sv_index_dim))
values = vector_data.isel(**{sv_index_dim: indices[:max_index]})
values = shift_time_axis_by_irf_location(
vector_data.isel(**{sv_index_dim: indices[:max_index]}), irf_location, _internal_call=True
)
x_dim = vector_data.dims[1]
if x_dim == sv_index_dim:
values = values.T
Expand Down
8 changes: 5 additions & 3 deletions pyglotaran_extras/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,7 @@ def extract_dataset_scale(res: xr.Dataset, divide_by_scale: bool = True) -> floa


def shift_time_axis_by_irf_location(
plot_data: xr.DataArray,
irf_location: float | None,
plot_data: xr.DataArray, irf_location: float | None, *, _internal_call: bool = False
) -> xr.DataArray:
"""Shift ``plot_data`` 'time' axis by the position of the main ``irf``.

Expand All @@ -318,6 +317,9 @@ def shift_time_axis_by_irf_location(
Data to plot.
irf_location : float | None
Location of the ``irf``, if the value is None the original ``plot_data`` will be returned.
_internal_call : bool
This indicates internal use stripping away user help and silently skipping execution.
Defaults to False.

Returns
-------
Expand All @@ -333,7 +335,7 @@ def shift_time_axis_by_irf_location(
--------
extract_irf_location
"""
if irf_location is None:
if irf_location is None or ("time" not in plot_data.coords and _internal_call is True):
return plot_data

if "time" not in plot_data.coords:
Expand Down