Skip to content

Commit

Permalink
🩹 Fix find_axes inside use_plot_config causing TypeError with DataArr…
Browse files Browse the repository at this point in the history
…ay (#303)

* 🧪 Added test reproducing error in find_axes when a xr.DataArray is passed
* 🩹 Fix issue with xr.DataArray
* 👌 Endure Mapping of Axes is handled correctly by find_axes
* 🚧📚 Added change to changelog
  • Loading branch information
s-weigand authored Sep 28, 2024
1 parent 4d38e46 commit 29b37dc
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

## 0.7.4 (Unreleased)

- 🩹 Fix find_axes inside use_plot_config causing TypeError with DataArray (#303)

(changes-0_7_3)=

## 0.7.3 (2024-08-25)
Expand Down
7 changes: 6 additions & 1 deletion pyglotaran_extras/config/plot_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import cast

import numpy as np
import xarray as xr
from docstring_parser import parse as parse_docstring
from matplotlib.axes import Axes
from pydantic import BaseModel
Expand Down Expand Up @@ -479,12 +480,16 @@ def find_axes(
Axes
"""
for value in values:
if isinstance(value, str):
# This are iterables where we are sure that they can not contain `Axes` so we can skip them
# early
if isinstance(value, str | xr.Dataset | xr.DataArray):
continue
elif isinstance(value, Axes):
yield value
elif isinstance(value, np.ndarray):
yield from find_axes(value.flatten())
elif isinstance(value, Mapping):
yield from find_axes(value.values())
elif isinstance(value, Iterable):
yield from find_axes(value)

Expand Down
17 changes: 13 additions & 4 deletions tests/config/test_plot_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import matplotlib.pyplot as plt
import pytest
import xarray as xr
from jsonschema import ValidationError as SchemaValidationError
from jsonschema import validate
from pydantic import ValidationError as PydanticValidationError
Expand Down Expand Up @@ -466,15 +467,18 @@ def func(
def test_find_axes():
"""Get axes value from iterable of values."""

base_values = ["foo", True, 1.5]
data_array = xr.DataArray([[0, 1]], coords={"time": [0], "spectral": [500, 510]})
data_set = xr.Dataset({"data": data_array})

base_values = ["foo", True, 1.5, data_array, data_set]

assert generator_is_exhausted(find_axes(base_values)) is True

_, ax = plt.subplots()
single_ax_gen = find_axes([*base_values, ax])
dict_ax_gen = find_axes([*base_values, ax])

assert next(single_ax_gen) is ax
assert generator_is_exhausted(single_ax_gen) is True
assert next(dict_ax_gen) is ax
assert generator_is_exhausted(dict_ax_gen) is True

_, np_axes = plt.subplots(1, 2)

Expand All @@ -500,6 +504,11 @@ def test_find_axes():
assert next(multiple_axes_gen) is ax1
assert generator_is_exhausted(multiple_axes_gen) is True

dict_ax_gen = find_axes([*base_values, {"ax": ax}])

assert next(dict_ax_gen) is ax
assert generator_is_exhausted(dict_ax_gen) is True


def test_use_plot_config(mock_config: tuple[Config, dict[str, Any]]):
"""Config is applied to functions with the ``use_plot_config`` decorator."""
Expand Down

0 comments on commit 29b37dc

Please sign in to comment.