diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml new file mode 100644 index 00000000..245a7e08 --- /dev/null +++ b/.github/workflows/integration-tests.yml @@ -0,0 +1,48 @@ +name: "Run Examples" + +on: + push: + pull_request: + workflow_dispatch: + +jobs: + run-examples: + name: "Run Example: " + runs-on: ubuntu-latest + strategy: + matrix: + example_name: + [ + quick-start, + fluorescence, + transient-absorption, + transient-absorption-two-datasets, + spectral-constraints, + spectral-guidance, + two-datasets, + sim-3d-disp, + sim-3d-nodisp, + sim-3d-weight, + sim-6d-disp, + ] + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Install pyglotaran-extras + run: | + pip install wheel + pip install . + - name: ${{ matrix.example_name }} + id: example-run + uses: glotaran/pyglotaran-examples@main + with: + example_name: ${{ matrix.example_name }} + install_extras: false + - name: Upload Example Plots Artifact + uses: actions/upload-artifact@v2 + with: + name: example-plots + path: ${{ steps.example-run.outputs.plots-path }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7090b776..7ce9846c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,7 @@ repos: - id: absolufy-imports - repo: https://github.com/asottile/pyupgrade - rev: v2.26.0 + rev: v2.29.0 hooks: - id: pyupgrade args: [--py38-plus] @@ -36,13 +36,13 @@ repos: minimum_pre_commit_version: 2.9.0 - repo: https://github.com/asottile/yesqa - rev: v1.2.3 + rev: v1.3.0 hooks: - id: yesqa additional_dependencies: [flake8-docstrings] - repo: https://github.com/asottile/setup-cfg-fmt - rev: v1.17.0 + rev: v1.18.0 hooks: - id: setup-cfg-fmt @@ -76,7 +76,7 @@ repos: # Linters - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.910 + rev: v0.910-1 hooks: - id: mypy exclude: ^docs @@ -90,7 +90,7 @@ repos: # pass_filenames: false - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 + rev: 4.0.1 hooks: - id: flake8 args: @@ -111,5 +111,6 @@ repos: rev: v2.1.0 hooks: - id: codespell - files: ".py|.rst" - args: [-L pyglotaran] + types: [file] + types_or: [python, pyi, markdown, rst, jupyter] + args: ["--ignore-words-list=doas"] diff --git a/pyglotaran_extras/__init__.py b/pyglotaran_extras/__init__.py index e19434e2..1e10cd8a 100644 --- a/pyglotaran_extras/__init__.py +++ b/pyglotaran_extras/__init__.py @@ -1 +1,19 @@ -__version__ = "0.3.3" +from pyglotaran_extras.io.load_data import load_data +from pyglotaran_extras.io.setup_case_study import setup_case_study +from pyglotaran_extras.plotting.plot_data import plot_data_overview +from pyglotaran_extras.plotting.plot_overview import plot_overview +from pyglotaran_extras.plotting.plot_overview import plot_simple_overview +from pyglotaran_extras.plotting.plot_traces import plot_fitted_traces +from pyglotaran_extras.plotting.plot_traces import select_plot_wavelengths + +__all__ = [ + "load_data", + "setup_case_study", + "plot_data_overview", + "plot_overview", + "plot_simple_overview", + "plot_fitted_traces", + "select_plot_wavelengths", +] + +__version__ = "0.5.0rc1" diff --git a/pyglotaran_extras/deprecation/__init__.py b/pyglotaran_extras/deprecation/__init__.py new file mode 100644 index 00000000..fd1a4ef8 --- /dev/null +++ b/pyglotaran_extras/deprecation/__init__.py @@ -0,0 +1,4 @@ +"""Module containing deprecation functionality.""" +from pyglotaran_extras.deprecation.deprecation_utils import warn_deprecated + +__all__ = ["warn_deprecated"] diff --git a/pyglotaran_extras/deprecation/deprecation_utils.py b/pyglotaran_extras/deprecation/deprecation_utils.py new file mode 100644 index 00000000..a840d941 --- /dev/null +++ b/pyglotaran_extras/deprecation/deprecation_utils.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +from importlib.metadata import distribution +from warnings import warn + +FIG_ONLY_WARNING = ( + "In the future plot functions which create figures, will return a tuple " + "of Figure AND the Axes. Please set ``figure_only=False`` and adjust your code.\n" + "This usage will be an error in version: 0.7.0." +) + + +class OverDueDeprecation(Exception): + """Error thrown when a deprecation should have been removed. + + See Also + -------- + deprecate + warn_deprecated + deprecate_module_attribute + deprecate_submodule + deprecate_dict_entry + """ + + +class PyglotaranExtrasApiDeprecationWarning(UserWarning): + """Warning to give users about API changes. + + See Also + -------- + warn_deprecated + """ + + +def pyglotaran_extras_version() -> str: + """Version of the distribution. + + This is basically the same as ``pyglotaran_extras.__version__`` but independent + from pyglotaran_extras. + This way all of the deprecation functionality can be used even in + ``pyglotaran_extras.__init__.py`` without moving the import below the definition of + ``__version__`` or causeing a circular import issue. + + Returns + ------- + str + The version string. + """ + return distribution("pyglotaran-extras").version + + +def parse_version(version_str: str) -> tuple[int, int, int]: + """Parse version string to tuple of three ints for comparison. + + Parameters + ---------- + version_str : str + Fully qualified version string of the form 'major.minor.patch'. + + Returns + ------- + tuple[int, int, int] + Version as tuple. + + Raises + ------ + ValueError + If ``version_str`` has less that three elements separated by ``.``. + ValueError + If ``version_str`` 's first three elements can not be casted to int. + """ + error_message = ( + "version_str needs to be a fully qualified version consisting of " + f"int parts (e.g. '0.0.1'), got {version_str!r}" + ) + split_version = version_str.partition("-")[0].split(".") + if len(split_version) < 3: + raise ValueError(error_message) + try: + return tuple( + map(int, (*split_version[:2], split_version[2].partition("rc")[0])) + ) # type:ignore[return-value] + except ValueError: + raise ValueError(error_message) + + +def check_overdue(deprecated_qual_name_usage: str, to_be_removed_in_version: str) -> None: + """Check if a deprecation is overdue for removal. + + Parameters + ---------- + deprecated_qual_name_usage : str + Old usage with fully qualified name e.g.: + ``'glotaran.read_model_from_yaml(model_yml_str)'`` + to_be_removed_in_version : str + Version the support for this usage will be removed. + + Raises + ------ + OverDueDeprecation + If the current version is greater or equal to ``to_be_removed_in_version``. + """ + if ( + parse_version(pyglotaran_extras_version()) >= parse_version(to_be_removed_in_version) + and "dev" not in pyglotaran_extras_version() + ): + raise OverDueDeprecation( + f"Support for {deprecated_qual_name_usage.partition('(')[0]!r} was " + f"supposed to be dropped in version: {to_be_removed_in_version!r}.\n" + f"Current version is: {pyglotaran_extras_version()!r}" + ) + + +def warn_deprecated( + *, + deprecated_qual_name_usage: str, + new_qual_name_usage: str, + to_be_removed_in_version: str, + stacklevel: int = 2, +) -> None: + """Raise deprecation warning with change information. + + The change information are old / new usage information and end of support version. + + Parameters + ---------- + deprecated_qual_name_usage : str + Old usage with fully qualified name e.g.: + ``'glotaran.read_model_from_yaml(model_yml_str)'`` + new_qual_name_usage : str + New usage as fully qualified name e.g.: + ``'glotaran.io.load_model(model_yml_str, format_name="yml_str")'`` + to_be_removed_in_version : str + Version the support for this usage will be removed. + + stacklevel: int + Stack at which the warning should be shown as raise. Default: 2 + + Raises + ------ + OverDueDeprecation + If the current version is greater or equal to ``to_be_removed_in_version``. + """ + check_overdue(deprecated_qual_name_usage, to_be_removed_in_version) + warn( + PyglotaranExtrasApiDeprecationWarning( + f"Usage of {deprecated_qual_name_usage!r} was deprecated, " + f"use {new_qual_name_usage!r} instead.\n" + f"This usage will be an error in version: {to_be_removed_in_version!r}." + ), + stacklevel=stacklevel, + ) diff --git a/pyglotaran_extras/io/__init__.py b/pyglotaran_extras/io/__init__.py index d6acd389..ee680461 100644 --- a/pyglotaran_extras/io/__init__.py +++ b/pyglotaran_extras/io/__init__.py @@ -1,2 +1,5 @@ """ pyglotaran-extras io package """ -from pyglotaran_extras.io.boilerplate import setup_case_study +from pyglotaran_extras.io.load_data import load_data +from pyglotaran_extras.io.setup_case_study import setup_case_study + +__all__ = ["setup_case_study", "load_data"] diff --git a/pyglotaran_extras/io/boilerplate.py b/pyglotaran_extras/io/boilerplate.py index 54a79361..cba7b695 100644 --- a/pyglotaran_extras/io/boilerplate.py +++ b/pyglotaran_extras/io/boilerplate.py @@ -1,81 +1,12 @@ -from __future__ import annotations - -import inspect -from os import PathLike -from pathlib import Path - - -def setup_case_study( - output_folder_name: str = "pyglotaran_results", - results_folder_root: None | str | PathLike[str] = None, -) -> tuple[Path, Path]: - """Convenience function to to quickly get folders for a case study. - - This is an execution environment independent (works in python script files - and notebooks, independent of where the python runtime was called from) - way to get the folder the analysis code resides in and also creates the - ``results_folder`` in case it didn't exist before. - - Parameters - ---------- - output_folder_name : str - Name of the base folder for the results., by default "pyglotaran_results" - results_folder_root : None - Parent folder the ``output_folder_name`` should be put in - , by default None which results in the users Home folder being used - - Returns - ------- - tuple[Path, Path] - results_folder, script_folder: - - results_folder: - Folder to be used to save results in of the pattern - (``results_folder_root`` / ``output_folder_name`` / ``analysis_folder.parent``). - analysis_folder: - Folder the script or Notebook resides in. - """ - analysis_folder = get_script_dir(nesting=1) - print(f"Setting up case study for folder: {analysis_folder}") - if results_folder_root is None: - results_folder_root = Path.home() / output_folder_name - else: - results_folder_root = Path(str(results_folder_root)) - script_folder_rel = analysis_folder.relative_to(analysis_folder.parent) - results_folder = (results_folder_root / script_folder_rel).resolve() - results_folder.mkdir(parents=True, exist_ok=True) - print(f"Results will be saved in: {results_folder}") - return results_folder, analysis_folder.resolve() - - -def get_script_dir(*, nesting: int = 0) -> Path: - """Gets parent folder a script is executed in. - - This is a helper function for cross compatibility with jupyter notebooks. - In notebooks the global ``__file__`` variable isn't set, thus we need different - means to get the folder a script is defined in, which doesn't change with the - current working director the ``python interpreter`` was called from. - - Parameters - ---------- - nesting : int - Number to go up in the call stack to get to the initially calling function. - This is only needed for library code and not for user code. - , by default 0 (direct call) - - Returns - ------- - Path - Path to the folder the script was resides in. - - See Also - -------- - setup_case_study - """ - calling_frame = inspect.stack()[nesting + 1].frame - file_var = calling_frame.f_globals.get("__file__", ".") - file_path = Path(file_var).resolve() - if file_var == ".": # pragma: no cover - return file_path - else: - return file_path.parent +"""Deprecated module.""" +from pyglotaran_extras.deprecation import warn_deprecated +from pyglotaran_extras.io.setup_case_study import setup_case_study + +__all__ = ["setup_case_study"] + +warn_deprecated( + deprecated_qual_name_usage="pyglotaran_extras.io.boilerplate", + new_qual_name_usage="pyglotaran_extras.io", + to_be_removed_in_version="0.7.0", + stacklevel=3, +) diff --git a/pyglotaran_extras/io/setup_case_study.py b/pyglotaran_extras/io/setup_case_study.py new file mode 100644 index 00000000..6225bb72 --- /dev/null +++ b/pyglotaran_extras/io/setup_case_study.py @@ -0,0 +1,82 @@ +"""Module contain function to initialize a case study.""" +from __future__ import annotations + +import inspect +from os import PathLike +from pathlib import Path + + +def setup_case_study( + output_folder_name: str = "pyglotaran_results", + results_folder_root: None | str | PathLike[str] = None, +) -> tuple[Path, Path]: + """Convenience function to to quickly get folders for a case study. + + This is an execution environment independent (works in python script files + and notebooks, independent of where the python runtime was called from) + way to get the folder the analysis code resides in and also creates the + ``results_folder`` in case it didn't exist before. + + Parameters + ---------- + output_folder_name : str + Name of the base folder for the results., by default "pyglotaran_results" + results_folder_root : None + Parent folder the ``output_folder_name`` should be put in + , by default None which results in the users Home folder being used + + Returns + ------- + tuple[Path, Path] + results_folder, script_folder: + + results_folder: + Folder to be used to save results in of the pattern + (``results_folder_root`` / ``output_folder_name`` / ``analysis_folder.parent``). + analysis_folder: + Folder the script or Notebook resides in. + """ + analysis_folder = get_script_dir(nesting=1) + print(f"Setting up case study for folder: {analysis_folder}") + if results_folder_root is None: + results_folder_root = Path.home() / output_folder_name + else: + results_folder_root = Path(str(results_folder_root)) + script_folder_rel = analysis_folder.relative_to(analysis_folder.parent) + results_folder = (results_folder_root / script_folder_rel).resolve() + results_folder.mkdir(parents=True, exist_ok=True) + print(f"Results will be saved in: {results_folder}") + return results_folder, analysis_folder.resolve() + + +def get_script_dir(*, nesting: int = 0) -> Path: + """Gets parent folder a script is executed in. + + This is a helper function for cross compatibility with jupyter notebooks. + In notebooks the global ``__file__`` variable isn't set, thus we need different + means to get the folder a script is defined in, which doesn't change with the + current working director the ``python interpreter`` was called from. + + Parameters + ---------- + nesting : int + Number to go up in the call stack to get to the initially calling function. + This is only needed for library code and not for user code. + , by default 0 (direct call) + + Returns + ------- + Path + Path to the folder the script was resides in. + + See Also + -------- + setup_case_study + """ + calling_frame = inspect.stack()[nesting + 1].frame + file_var = calling_frame.f_globals.get("__file__", ".") + file_path = Path(file_var).resolve() + if file_var == ".": # pragma: no cover + return file_path + else: + return file_path.parent diff --git a/pyglotaran_extras/plotting/__init__.py b/pyglotaran_extras/plotting/__init__.py index 44d3f36a..56a7dde5 100644 --- a/pyglotaran_extras/plotting/__init__.py +++ b/pyglotaran_extras/plotting/__init__.py @@ -1 +1,6 @@ """ pyglotaran-extras plotting package """ +from pyglotaran_extras.plotting.plot_data import plot_data_overview +from pyglotaran_extras.plotting.plot_overview import plot_overview +from pyglotaran_extras.plotting.plot_overview import plot_simple_overview + +__all__ = ["plot_data_overview", "plot_overview", "plot_simple_overview"] diff --git a/pyglotaran_extras/plotting/data.py b/pyglotaran_extras/plotting/data.py index ed3f6118..59d06798 100644 --- a/pyglotaran_extras/plotting/data.py +++ b/pyglotaran_extras/plotting/data.py @@ -1,64 +1,13 @@ -from __future__ import annotations +"""Deprecated module.""" -from typing import TYPE_CHECKING +from pyglotaran_extras.deprecation import warn_deprecated +from pyglotaran_extras.plotting.plot_data import plot_data_overview -import matplotlib.pyplot as plt +__all__ = ["plot_data_overview"] -from pyglotaran_extras.plotting.plot_svd import plot_lsv_data -from pyglotaran_extras.plotting.plot_svd import plot_rsv_data -from pyglotaran_extras.plotting.plot_svd import plot_sv_data -from pyglotaran_extras.plotting.utils import select_plot_wavelengths - -__all__ = ["select_plot_wavelengths", "plot_data_overview"] - -if TYPE_CHECKING: - from matplotlib.figure import Figure - from matplotlib.pyplot import Axes - from xarray import Dataset - - -def plot_data_overview( - dataset: Dataset, - title="Data overview", - linlog: bool = False, - linthresh: float = 1, - figsize: tuple[int, int] = (30, 15), -) -> tuple[Figure, Axes]: - """Plot data as filled contour plot and SVD components. - - Parameters - ---------- - dataset : Dataset - Dataset containing data and SVD of the data. - title : str, optional - Title to add to the figure., by default "Data overview" - linlog : bool, optional - Whether to use 'symlog' scale or not, by default False - linthresh : float, optional - A single float which defines the range (-x, x), within which the plot is linear. - This avoids having the plot go to infinity around zero., by default 1 - - Returns - ------- - tuple[Figure, Axes] - Figure and axes which can then be refined by the user. - """ - fig = plt.figure(figsize=figsize) - data_ax = plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3, fig=fig) - lsv_ax = plt.subplot2grid((4, 3), (3, 0), fig=fig) - sv_ax = plt.subplot2grid((4, 3), (3, 1), fig=fig) - rsv_ax = plt.subplot2grid((4, 3), (3, 2), fig=fig) - - if len(dataset.data.time) > 1: - dataset.data.plot(x="time", ax=data_ax, center=False) - else: - dataset.data.plot(ax=data_ax) - plot_lsv_data(dataset, lsv_ax) - plot_sv_data(dataset, sv_ax) - plot_rsv_data(dataset, rsv_ax) - fig.suptitle(title, fontsize=16) - fig.tight_layout() - - if linlog: - data_ax.set_xscale("symlog", linthresh=linthresh) - return fig, (data_ax, lsv_ax, sv_ax, rsv_ax) +warn_deprecated( + deprecated_qual_name_usage="pyglotaran_extras.plotting.data", + new_qual_name_usage="pyglotaran_extras.plotting.plot_data", + to_be_removed_in_version="0.7.0", + stacklevel=3, +) diff --git a/pyglotaran_extras/plotting/plot_concentrations.py b/pyglotaran_extras/plotting/plot_concentrations.py index a5f4ae81..e67925ca 100644 --- a/pyglotaran_extras/plotting/plot_concentrations.py +++ b/pyglotaran_extras/plotting/plot_concentrations.py @@ -2,33 +2,33 @@ from typing import TYPE_CHECKING -import matplotlib.pyplot as plt - from pyglotaran_extras.plotting.style import PlotStyle from pyglotaran_extras.plotting.utils import get_shifted_traces if TYPE_CHECKING: import xarray as xr - from matplotlib.pyplot import Axes + from cycler import Cycler + from matplotlib.axis import Axis def plot_concentrations( res: xr.Dataset, - ax: Axes, + ax: Axis, center_λ: float | None, linlog: bool = False, linthresh: float = 1, linscale: float = 1, main_irf_nr: int = 0, + cycler: Cycler = PlotStyle().cycler, ) -> None: - """Plot traces on the given axis ``ax`` + """Plot traces on the given axis ``ax``. Parameters ---------- res: xr.Dataset Result dataset from a pyglotaran optimization. - ax: Axes - Axes to plot the traces on + ax: Axis + Axis to plot the traces on center_λ: float | None Center wavelength (λ in nm) linlog: bool @@ -46,14 +46,15 @@ def plot_concentrations( main_irf_nr: int Index of the main ``irf`` component when using an ``irf`` parametrized with multiple peaks , by default 0 + cycler : Cycler + Plot style cycler to use., by default PlotStyle().data_cycler_solid See Also -------- get_shifted_traces """ + ax.set_prop_cycle(cycler) traces = get_shifted_traces(res, center_λ, main_irf_nr) - plot_style = PlotStyle() - plt.rc("axes", prop_cycle=plot_style.cycler) if "spectral" in traces.coords: traces.sel(spectral=center_λ, method="nearest").plot.line(x="time", ax=ax) diff --git a/pyglotaran_extras/plotting/plot_data.py b/pyglotaran_extras/plotting/plot_data.py new file mode 100644 index 00000000..59b64ea0 --- /dev/null +++ b/pyglotaran_extras/plotting/plot_data.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import matplotlib.pyplot as plt + +from pyglotaran_extras.plotting.plot_svd import plot_lsv_data +from pyglotaran_extras.plotting.plot_svd import plot_rsv_data +from pyglotaran_extras.plotting.plot_svd import plot_sv_data + +__all__ = ["plot_data_overview"] + +if TYPE_CHECKING: + from typing import cast + + import xarray as xr + from matplotlib.axis import Axis + from matplotlib.figure import Figure + from matplotlib.pyplot import Axes + + +def plot_data_overview( + dataset: xr.Dataset, + title: str = "Data overview", + linlog: bool = False, + linthresh: float = 1, + figsize: tuple[int, int] = (30, 15), +) -> tuple[Figure, Axes]: + """Plot data as filled contour plot and SVD components. + + Parameters + ---------- + dataset : Dataset + Dataset containing data and SVD of the data. + title : str, optional + Title to add to the figure., by default "Data overview" + linlog : bool, optional + Whether to use 'symlog' scale or not, by default False + linthresh : float, optional + A single float which defines the range (-x, x), within which the plot is linear. + This avoids having the plot go to infinity around zero., by default 1 + + Returns + ------- + tuple[Figure, Axes] + Figure and axes which can then be refined by the user. + """ + fig = plt.figure(figsize=figsize) + data_ax = cast(Axis, plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3, fig=fig)) + lsv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 0), fig=fig)) + sv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 1), fig=fig)) + rsv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 2), fig=fig)) + + if len(dataset.data.time) > 1: + dataset.data.plot(x="time", ax=data_ax, center=False) + else: + dataset.data.plot(ax=data_ax) + plot_lsv_data(dataset, lsv_ax) + plot_sv_data(dataset, sv_ax) + plot_rsv_data(dataset, rsv_ax) + fig.suptitle(title, fontsize=16) + fig.tight_layout() + + if linlog: + data_ax.set_xscale("symlog", linthresh=linthresh) + return fig, (data_ax, lsv_ax, sv_ax, rsv_ax) diff --git a/pyglotaran_extras/plotting/plot_doas.py b/pyglotaran_extras/plotting/plot_doas.py index d8188a35..b3ba7d23 100644 --- a/pyglotaran_extras/plotting/plot_doas.py +++ b/pyglotaran_extras/plotting/plot_doas.py @@ -1,45 +1,93 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from warnings import warn + import matplotlib.pyplot as plt +from pyglotaran_extras.deprecation.deprecation_utils import FIG_ONLY_WARNING +from pyglotaran_extras.deprecation.deprecation_utils import PyglotaranExtrasApiDeprecationWarning from pyglotaran_extras.io.load_data import load_data +from pyglotaran_extras.plotting.style import PlotStyle + +if TYPE_CHECKING: + from cycler import Cycler + from matplotlib.figure import Figure + from matplotlib.pyplot import Axes + + from pyglotaran_extras.types import DatasetConvertible -def plot_doas(path): - dataset = load_data(path) +def plot_doas( + result: DatasetConvertible, + figsize: tuple[int, int] = (25, 25), + cycler: Cycler = PlotStyle().cycler, + figure_only: bool = True, +) -> Figure | tuple[Figure, Axes]: + """Plot damped oscillations (DOAS). + + Parameters + ---------- + result: DatasetConvertible + Result from a pyglotaran optimization as dataset, Path or Result object. + figsize : tuple[int, int] + Size of the figure (N, M) in inches., by default (18, 16) + cycler : Cycler + Plot style cycler to use., by default PlotStyle().cycler + figure_only: bool + Whether or not to only return the figure. + This is a deprecation helper argument to transition to a consistent return value + consisting of the :class:`Figure` and the :class:`Axes`, by default True + + Returns + ------- + Figure|tuple[Figure, Axes] + If ``figure_only`` is True, Figure object which contains the plots (deprecated). + If ``figure_only`` is False, Figure object which contains the plots and the Axes. + """ + dataset = load_data(result) # Create M x N plotting grid M = 6 N = 3 - fig, ax = plt.subplots(M, N, figsize=(25, 25)) + fig, axes = plt.subplots(M, N, figsize=figsize) + + for ax in axes.flatten(): + ax.set_prop_cycle(cycler) # Plot data - dataset.species_associated_spectra.plot.line(x="spectral", ax=ax[0, 0]) - dataset.decay_associated_spectra.plot.line(x="spectral", ax=ax[0, 1]) + dataset.species_associated_spectra.plot.line(x="spectral", ax=axes[0, 0]) + dataset.decay_associated_spectra.plot.line(x="spectral", ax=axes[0, 1]) if "spectral" in dataset.species_concentration.coords: - dataset.species_concentration.isel(spectral=0).plot.line(x="time", ax=ax[1, 0]) + dataset.species_concentration.isel(spectral=0).plot.line(x="time", ax=axes[1, 0]) else: - dataset.species_concentration.plot.line(x="time", ax=ax[1, 0]) - ax[1, 0].set_xscale("symlog", linthreshx=1) + dataset.species_concentration.plot.line(x="time", ax=axes[1, 0]) + axes[1, 0].set_xscale("symlog", linthreshx=1) if "dampened_oscillation_associated_spectra" in dataset: dataset.dampened_oscillation_cos.isel(spectral=0).sel(time=slice(-1, 10)).plot.line( - x="time", ax=ax[1, 1] + x="time", ax=axes[1, 1] ) - dataset.dampened_oscillation_associated_spectra.plot.line(x="spectral", ax=ax[2, 0]) - dataset.dampened_oscillation_phase.plot.line(x="spectral", ax=ax[2, 1]) + dataset.dampened_oscillation_associated_spectra.plot.line(x="spectral", ax=axes[2, 0]) + dataset.dampened_oscillation_phase.plot.line(x="spectral", ax=axes[2, 1]) - dataset.residual_left_singular_vectors.isel(left_singular_value_index=0).plot(ax=ax[0, 2]) - dataset.residual_singular_values.plot.line("ro-", yscale="log", ax=ax[1, 2]) - dataset.residual_right_singular_vectors.isel(right_singular_value_index=0).plot(ax=ax[2, 2]) + dataset.residual_left_singular_vectors.isel(left_singular_value_index=0).plot(ax=axes[0, 2]) + dataset.residual_singular_values.plot.line("ro-", yscale="log", ax=axes[1, 2]) + dataset.residual_right_singular_vectors.isel(right_singular_value_index=0).plot(ax=axes[2, 2]) interval = int(dataset.spectral.size / 11) for i in range(0): - axi = ax[i % 3, int(i / 3) + 3] + axi = axes[i % 3, int(i / 3) + 3] index = (i + 1) * interval dataset.data.isel(spectral=index).plot(ax=axi) dataset.residual.isel(spectral=index).plot(ax=axi) dataset.fitted_data.isel(spectral=index).plot(ax=axi) plt.tight_layout(pad=5, w_pad=2.0, h_pad=2.0) - return fig + if figure_only is True: + warn(PyglotaranExtrasApiDeprecationWarning(FIG_ONLY_WARNING), stacklevel=2) + return fig + else: + return fig, axes diff --git a/pyglotaran_extras/plotting/plot_overview.py b/pyglotaran_extras/plotting/plot_overview.py index 29c2bdb7..c5b5297a 100644 --- a/pyglotaran_extras/plotting/plot_overview.py +++ b/pyglotaran_extras/plotting/plot_overview.py @@ -2,10 +2,12 @@ from pathlib import Path from typing import TYPE_CHECKING +from warnings import warn import matplotlib.pyplot as plt -import xarray as xr +from pyglotaran_extras.deprecation.deprecation_utils import FIG_ONLY_WARNING +from pyglotaran_extras.deprecation.deprecation_utils import PyglotaranExtrasApiDeprecationWarning from pyglotaran_extras.io.load_data import load_data from pyglotaran_extras.plotting.plot_concentrations import plot_concentrations from pyglotaran_extras.plotting.plot_residual import plot_residual @@ -14,24 +16,30 @@ from pyglotaran_extras.plotting.style import PlotStyle if TYPE_CHECKING: - from glotaran.project import Result + from cycler import Cycler from matplotlib.figure import Figure + from matplotlib.pyplot import Axes + + from pyglotaran_extras.types import DatasetConvertible def plot_overview( - result: xr.Dataset | Path | Result, + result: DatasetConvertible, center_λ: float | None = None, linlog: bool = True, linthresh: float = 1, linscale: float = 1, show_data: bool = False, main_irf_nr: int = 0, -) -> Figure: + figsize: tuple[int, int] = (18, 16), + cycler: Cycler = PlotStyle().cycler, + figure_only: bool = True, +) -> Figure | tuple[Figure, Axes]: """Plot overview of the optimization result. Parameters ---------- - result : xr.Dataset | Path | Result + result: DatasetConvertible Result from a pyglotaran optimization as dataset, Path or Result object. center_λ: float | None Center wavelength (λ in nm) @@ -47,16 +55,25 @@ def plot_overview( For example, when linscale == 1.0 (the default), the space used for the positive and negative halves of the linear range will be equal to one decade in the logarithmic range., by default 1 - show_data : bool + show_data: bool Whether to show the input data or residual, by default False main_irf_nr: int Index of the main ``irf`` component when using an ``irf`` parametrized with multiple peaks , by default 0 + figsize : tuple[int, int] + Size of the figure (N, M) in inches., by default (18, 16) + cycler : Cycler + Plot style cycler to use., by default PlotStyle().cycler + figure_only: bool + Whether or not to only return the figure. + This is a deprecation helper argument to transition to a consistent return value + consisting of the :class:`Figure` and the :class:`Axes`, by default True Returns ------- - Figure - Figure object which contains the plots. + Figure|tuple[Figure, Axes] + If ``figure_only`` is True, Figure object which contains the plots (deprecated). + If ``figure_only`` is False, Figure object which contains the plots and the Axes. """ res = load_data(result) @@ -64,10 +81,7 @@ def plot_overview( # Plot dimensions M = 4 N = 3 - fig, ax = plt.subplots(M, N, figsize=(18, 16), constrained_layout=True) - - plot_style = PlotStyle() - plt.rc("axes", prop_cycle=plot_style.cycler) + fig, axes = plt.subplots(M, N, figsize=figsize, constrained_layout=True) if center_λ is None: # center wavelength (λ in nm) center_λ = min(res.dims["spectral"], round(res.dims["spectral"] / 2)) @@ -75,63 +89,104 @@ def plot_overview( # First and second row: concentrations - SAS/EAS - DAS plot_concentrations( res, - ax[0, 0], + axes[0, 0], center_λ, linlog=linlog, linthresh=linthresh, linscale=linscale, main_irf_nr=main_irf_nr, + cycler=cycler, + ) + plot_spectra(res, axes[0:2, 1:3], cycler=cycler) + plot_svd(res, axes[2:4, 0:3], linlog=linlog, linthresh=linthresh, cycler=cycler) + plot_residual( + res, axes[1, 0], linlog=linlog, linthresh=linthresh, show_data=show_data, cycler=cycler ) - plot_spectra(res, ax[0:2, 1:3]) - plot_svd(res, ax[2:4, 0:3], linlog=linlog, linthresh=linthresh) - plot_residual(res, ax[1, 0], linlog=linlog, linthresh=linthresh, show_data=show_data) - plot_style.set_default_colors() - plot_style.set_default_fontsize() - plt.rc("axes", prop_cycle=plot_style.cycler) # plt.tight_layout(pad=3, w_pad=4.0, h_pad=4.0) - return fig + if figure_only is True: + warn(PyglotaranExtrasApiDeprecationWarning(FIG_ONLY_WARNING), stacklevel=2) + return fig + else: + return fig, axes + + +def plot_simple_overview( + result: DatasetConvertible, + title: str | None = None, + figsize: tuple[int, int] = (12, 6), + cycler: Cycler = PlotStyle().cycler, + figure_only: bool = True, +) -> Figure | tuple[Figure, Axes]: + """Simple plotting function . + + Parameters + ---------- + result: DatasetConvertible + Result from a pyglotaran optimization as dataset, Path or Result object. + title: str | None + Title of the figure., by default None + figsize : tuple[int, int] + Size of the figure (N, M) in inches., by default (18, 16) + cycler : Cycler + Plot style cycler to use., by default PlotStyle().cycler + figure_only: bool + Whether or not to only return the figure. + This is a deprecation helper argument to transition to a consistent return value + consisting of the :class:`Figure` and the :class:`Axes`, by default True + Returns + ------- + Figure|tuple[Figure, Axes] + If ``figure_only`` is True, Figure object which contains the plots (deprecated). + If ``figure_only`` is False, Figure object which contains the plots and the Axes. + """ + res = load_data(result) -def plot_simple_overview(res, title=None): - """simple plotting function derived from code from pyglotaran_extras""" - fig, ax = plt.subplots(2, 3, figsize=(12, 6), constrained_layout=True) + fig, axes = plt.subplots(2, 3, figsize=figsize, constrained_layout=True) + for ax in axes.flatten(): + ax.set_prop_cycle(cycler) if title: fig.suptitle(title, fontsize=16) sas = res.species_associated_spectra traces = res.species_concentration if "spectral" in traces.coords: traces.sel(spectral=res.spectral.values[0], method="nearest").plot.line( - x="time", ax=ax[0, 0] + x="time", ax=axes[0, 0] ) else: - traces.plot.line(x="time", ax=ax[0, 0]) - sas.plot.line(x="spectral", ax=ax[0, 1]) + traces.plot.line(x="time", ax=axes[0, 0]) + sas.plot.line(x="spectral", ax=axes[0, 1]) rLSV = res.residual_left_singular_vectors - rLSV.isel(left_singular_value_index=range(min(2, len(rLSV)))).plot.line(x="time", ax=ax[1, 0]) + rLSV.isel(left_singular_value_index=range(min(2, len(rLSV)))).plot.line( + x="time", ax=axes[1, 0] + ) - ax[1, 0].set_title("res. LSV") + axes[1, 0].set_title("res. LSV") rRSV = res.residual_right_singular_vectors rRSV.isel(right_singular_value_index=range(min(2, len(rRSV)))).plot.line( - x="spectral", ax=ax[1, 1] + x="spectral", ax=axes[1, 1] ) - ax[1, 1].set_title("res. RSV") - res.data.plot(x="time", ax=ax[0, 2]) - ax[0, 2].set_title("data") - res.residual.plot(x="time", ax=ax[1, 2]) - ax[1, 2].set_title("residual") - plt.show(block=False) - return fig + axes[1, 1].set_title("res. RSV") + res.data.plot(x="time", ax=axes[0, 2]) + axes[0, 2].set_title("data") + res.residual.plot(x="time", ax=axes[1, 2]) + axes[1, 2].set_title("residual") + if figure_only is True: + warn(PyglotaranExtrasApiDeprecationWarning(FIG_ONLY_WARNING), stacklevel=2) + return fig + else: + return fig, axes if __name__ == "__main__": import sys result_path = Path(sys.argv[1]) - res = xr.open_dataset(result_path) + res = load_data(result_path) print(res) - fig = plot_overview(res) + fig, plt.axes = plot_overview(res, figure_only=False) if len(sys.argv) > 2: fig.savefig(sys.argv[2], bbox_inches="tight") print(f"Saved figure to: {sys.argv[2]}") diff --git a/pyglotaran_extras/plotting/plot_residual.py b/pyglotaran_extras/plotting/plot_residual.py index 0000213c..5345a7df 100644 --- a/pyglotaran_extras/plotting/plot_residual.py +++ b/pyglotaran_extras/plotting/plot_residual.py @@ -1,7 +1,26 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np +from pyglotaran_extras.plotting.style import PlotStyle + +if TYPE_CHECKING: + import xarray as xr + from cycler import Cycler + from matplotlib.axis import Axis + -def plot_residual(res, ax, linlog=False, linthresh=1, show_data=False): +def plot_residual( + res: xr.Dataset, + ax: Axis, + linlog: bool = False, + linthresh: float = 1, + show_data: bool = False, + cycler: Cycler = PlotStyle().cycler, +) -> None: + ax.set_prop_cycle(cycler) data = res.data if show_data else res.residual title = "dataset" if show_data else "residual" shape = np.array(data.shape) diff --git a/pyglotaran_extras/plotting/plot_spectra.py b/pyglotaran_extras/plotting/plot_spectra.py index dce1537c..f8838e0a 100644 --- a/pyglotaran_extras/plotting/plot_spectra.py +++ b/pyglotaran_extras/plotting/plot_spectra.py @@ -1,14 +1,29 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np +from pyglotaran_extras.plotting.style import PlotStyle + +if TYPE_CHECKING: + import xarray as xr + from cycler import Cycler + from matplotlib.axis import Axis + from matplotlib.pyplot import Axes + -def plot_spectra(res, axes): +def plot_spectra(res: xr.Dataset, axes: Axes, cycler: Cycler = PlotStyle().cycler) -> None: plot_sas(res, axes[0, 0]) plot_das(res, axes[0, 1]) plot_norm_sas(res, axes[1, 0]) plot_norm_das(res, axes[1, 1]) -def plot_sas(res, ax, title="SAS"): +def plot_sas( + res: xr.Dataset, ax: Axis, title: str = "SAS", cycler: Cycler = PlotStyle().cycler +) -> None: + ax.set_prop_cycle(cycler) keys = [ v for v in res.data_vars if v.startswith(("species_associated_spectra", "species_spectra")) ] @@ -19,7 +34,10 @@ def plot_sas(res, ax, title="SAS"): ax.get_legend().remove() -def plot_norm_sas(res, ax, title="norm SAS"): +def plot_norm_sas( + res: xr.Dataset, ax: Axis, title: str = "norm SAS", cycler: Cycler = PlotStyle().cycler +) -> None: + ax.set_prop_cycle(cycler) keys = [ v for v in res.data_vars if v.startswith(("species_associated_spectra", "species_spectra")) ] @@ -31,7 +49,10 @@ def plot_norm_sas(res, ax, title="norm SAS"): ax.get_legend().remove() -def plot_das(res, ax, title="DAS"): +def plot_das( + res: xr.Dataset, ax: Axis, title: str = "DAS", cycler: Cycler = PlotStyle().cycler +) -> None: + ax.set_prop_cycle(cycler) keys = [ v for v in res.data_vars if v.startswith(("decay_associated_spectra", "species_spectra")) ] @@ -42,7 +63,10 @@ def plot_das(res, ax, title="DAS"): ax.get_legend().remove() -def plot_norm_das(res, ax, title="norm DAS"): +def plot_norm_das( + res: xr.Dataset, ax: Axis, title: str = "norm DAS", cycler: Cycler = PlotStyle().cycler +) -> None: + ax.set_prop_cycle(cycler) keys = [ v for v in res.data_vars if v.startswith(("decay_associated_spectra", "species_spectra")) ] diff --git a/pyglotaran_extras/plotting/plot_svd.py b/pyglotaran_extras/plotting/plot_svd.py index c5473295..be955577 100644 --- a/pyglotaran_extras/plotting/plot_svd.py +++ b/pyglotaran_extras/plotting/plot_svd.py @@ -1,14 +1,43 @@ -def plot_svd(res, axes, linlog=False, linthresh=1): - plot_lsv_residual(res, axes[0, 0], linlog=linlog, linthresh=linthresh) - plot_rsv_residual(res, axes[0, 1]) - plot_sv_residual(res, axes[0, 2]) - plot_lsv_data(res, axes[1, 0], linlog=linlog, linthresh=linthresh) - plot_rsv_data(res, axes[1, 1]) - plot_sv_data(res, axes[1, 2]) +from __future__ import annotations +from typing import TYPE_CHECKING -def plot_lsv_data(res, ax, indices=range(4), linlog=False, linthresh=1): +from pyglotaran_extras.plotting.style import PlotStyle + +if TYPE_CHECKING: + from typing import Sequence + + import xarray as xr + from cycler import Cycler + from matplotlib.axis import Axis + from matplotlib.pyplot import Axes + + +def plot_svd( + res: xr.Dataset, + axes: Axes, + linlog: bool = False, + linthresh: float = 1, + cycler: Cycler = PlotStyle().cycler, +) -> None: + plot_lsv_residual(res, axes[0, 0], linlog=linlog, linthresh=linthresh, cycler=cycler) + plot_rsv_residual(res, axes[0, 1], cycler=cycler) + plot_sv_residual(res, axes[0, 2], cycler=cycler) + plot_lsv_data(res, axes[1, 0], linlog=linlog, linthresh=linthresh, cycler=cycler) + plot_rsv_data(res, axes[1, 1], cycler=cycler) + plot_sv_data(res, axes[1, 2], cycler=cycler) + + +def plot_lsv_data( + res: xr.Dataset, + ax: Axis, + indices: Sequence[int] = range(4), + linlog: bool = False, + linthresh: float = 1, + cycler: Cycler = PlotStyle().cycler, +) -> None: """Plot left singular vectors (time) of the data matrix""" + ax.set_prop_cycle(cycler) dLSV = res.data_left_singular_vectors dLSV.isel(left_singular_value_index=indices[: len(dLSV.left_singular_value_index)]).plot.line( x="time", ax=ax @@ -18,8 +47,14 @@ def plot_lsv_data(res, ax, indices=range(4), linlog=False, linthresh=1): ax.set_xscale("symlog", linthresh=linthresh) -def plot_rsv_data(res, ax, indices=range(4)): +def plot_rsv_data( + res: xr.Dataset, + ax: Axis, + indices: Sequence[int] = range(4), + cycler: Cycler = PlotStyle().cycler, +) -> None: """Plot right singular vectors (spectra) of the data matrix""" + ax.set_prop_cycle(cycler) dRSV = res.data_right_singular_vectors dRSV.isel( right_singular_value_index=indices[: len(dRSV.right_singular_value_index)] @@ -27,8 +62,14 @@ def plot_rsv_data(res, ax, indices=range(4)): ax.set_title("data. RSV") -def plot_sv_data(res, ax, indices=range(10)): +def plot_sv_data( + res: xr.Dataset, + ax: Axis, + indices: Sequence[int] = range(10), + cycler: Cycler = PlotStyle().cycler, +) -> None: """Plot singular values of the data matrix""" + ax.set_prop_cycle(cycler) dSV = res.data_singular_values dSV.sel(singular_value_index=indices[: len(dSV.singular_value_index)]).plot.line( "ro-", yscale="log", ax=ax @@ -36,8 +77,16 @@ def plot_sv_data(res, ax, indices=range(10)): ax.set_title("data. log(SV)") -def plot_lsv_residual(res, ax, indices=range(2), label="residual", linlog=False, linthresh=1): +def plot_lsv_residual( + res: xr.Dataset, + ax: Axis, + indices: Sequence[int] = range(2), + linlog: bool = False, + linthresh: float = 1, + cycler: Cycler = PlotStyle().cycler, +) -> None: """Plot left singular vectors (time) of the residual matrix""" + ax.set_prop_cycle(cycler) if "weighted_residual_left_singular_vectors" in res: rLSV = res.weighted_residual_left_singular_vectors else: @@ -50,8 +99,14 @@ def plot_lsv_residual(res, ax, indices=range(2), label="residual", linlog=False, ax.set_xscale("symlog", linthresh=linthresh) -def plot_rsv_residual(res, ax, indices=range(2)): +def plot_rsv_residual( + res: xr.Dataset, + ax: Axis, + indices: Sequence[int] = range(2), + cycler: Cycler = PlotStyle().cycler, +) -> None: """Plot right singular vectors (spectra) of the residual matrix""" + ax.set_prop_cycle(cycler) if "weighted_residual_right_singular_vectors" in res: rRSV = res.weighted_residual_right_singular_vectors else: @@ -62,8 +117,14 @@ def plot_rsv_residual(res, ax, indices=range(2)): ax.set_title("res. RSV") -def plot_sv_residual(res, ax, indices=range(10)): +def plot_sv_residual( + res: xr.Dataset, + ax: Axis, + indices: Sequence[int] = range(10), + cycler: Cycler = PlotStyle().cycler, +) -> None: """Plot singular values of the residual matrix""" + ax.set_prop_cycle(cycler) if "weighted_residual_singular_values" in res: rSV = res.weighted_residual_singular_values else: diff --git a/pyglotaran_extras/plotting/plot_traces.py b/pyglotaran_extras/plotting/plot_traces.py index 492e1a16..55668a0b 100644 --- a/pyglotaran_extras/plotting/plot_traces.py +++ b/pyglotaran_extras/plotting/plot_traces.py @@ -19,6 +19,7 @@ from typing import Iterable from cycler import Cycler + from matplotlib.axis import Axis from matplotlib.figure import Figure from matplotlib.pyplot import Axes @@ -28,7 +29,7 @@ def plot_data_and_fits( result: ResultLike, wavelength: float, - axis: Axes, + axis: Axis, center_λ: float | None = None, main_irf_nr: int = 0, linlog: bool = False, @@ -48,7 +49,7 @@ def plot_data_and_fits( Data structure which can be converted to a mapping. wavelength : float Wavelength to plot data and fits for. - axis : Axes + axis: Axis Axis to plot the data and fits on. center_λ: float | None Center wavelength (λ in nm) diff --git a/pyglotaran_extras/plotting/style.py b/pyglotaran_extras/plotting/style.py index 7aeec540..be7c9a19 100644 --- a/pyglotaran_extras/plotting/style.py +++ b/pyglotaran_extras/plotting/style.py @@ -26,12 +26,12 @@ class ColorCode(Enum): indigo = "#4b0082" @staticmethod - def hex_to_rgb(hex_string: str): + def hex_to_rgb(hex_string: str) -> tuple[int, ...]: rgb = colors.hex2color(hex_string) return tuple(int(255 * x) for x in rgb) @staticmethod - def rgb_to_hex(rgb_tuple: tuple[float, float, float]): + def rgb_to_hex(rgb_tuple: tuple[float, ...]) -> str: return colors.rgb2hex([1.0 * x / 255 for x in rgb_tuple]) diff --git a/setup.cfg b/setup.cfg index 49318176..c983d6b1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,6 +34,11 @@ install_requires = python_requires = >=3.8,<3.10 zip_safe = True +[options.packages.find] +include = + pyglotaran_extras + pyglotaran_extras.* + [rstcheck] ignore_directives = autoattribute,autoclass,autoexception,autofunction,automethod,automodule,highlight @@ -48,3 +53,6 @@ convention = numpy ignore_missing_imports = True scripts_are_modules = True show_error_codes = True + +[mypy-pyglotaran_extras.*] +disallow_incomplete_defs = True diff --git a/tests/conftest.py b/tests/conftest.py index 6cd3dc91..46d4bfb2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -from pyglotaran_extras.io.boilerplate import get_script_dir +from pyglotaran_extras.io.setup_case_study import get_script_dir def wrapped_get_script_dir(): diff --git a/tests/deprecation/__init__.py b/tests/deprecation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/deprecation/modules/__init__.py b/tests/deprecation/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/deprecation/modules/test_io_boilerplate.py b/tests/deprecation/modules/test_io_boilerplate.py new file mode 100644 index 00000000..5ba1fb5b --- /dev/null +++ b/tests/deprecation/modules/test_io_boilerplate.py @@ -0,0 +1,15 @@ +from pathlib import Path + +import pytest + +from pyglotaran_extras.deprecation.deprecation_utils import PyglotaranExtrasApiDeprecationWarning + + +def test_io_boilerplate(): + """Importing from ``pyglotaran_extras.io.boilerplate`` raises deprecation warning.""" + with pytest.warns(PyglotaranExtrasApiDeprecationWarning) as record: + from pyglotaran_extras.io.boilerplate import setup_case_study # noqa:F401 + + assert len(record) == 1 + assert Path(record[0].filename) == Path(__file__) + assert "'pyglotaran_extras.io'" in record[0].message.args[0] diff --git a/tests/deprecation/modules/test_plotting_data.py b/tests/deprecation/modules/test_plotting_data.py new file mode 100644 index 00000000..fb38a032 --- /dev/null +++ b/tests/deprecation/modules/test_plotting_data.py @@ -0,0 +1,15 @@ +from pathlib import Path + +import pytest + +from pyglotaran_extras.deprecation.deprecation_utils import PyglotaranExtrasApiDeprecationWarning + + +def test_io_boilerplate(): + """Importing from ``pyglotaran_extras.plotting.data`` raises deprecation warning.""" + with pytest.warns(PyglotaranExtrasApiDeprecationWarning) as record: + from pyglotaran_extras.plotting.data import plot_data_overview # noqa:F401 + + assert len(record) == 1 + assert Path(record[0].filename) == Path(__file__) + assert "'pyglotaran_extras.plotting.data'" in record[0].message.args[0] diff --git a/tests/deprecation/test_deprecation_utils.py b/tests/deprecation/test_deprecation_utils.py new file mode 100644 index 00000000..15416eb4 --- /dev/null +++ b/tests/deprecation/test_deprecation_utils.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import pytest + +import pyglotaran_extras +from pyglotaran_extras.deprecation.deprecation_utils import OverDueDeprecation +from pyglotaran_extras.deprecation.deprecation_utils import PyglotaranExtrasApiDeprecationWarning +from pyglotaran_extras.deprecation.deprecation_utils import check_overdue +from pyglotaran_extras.deprecation.deprecation_utils import parse_version +from pyglotaran_extras.deprecation.deprecation_utils import pyglotaran_extras_version +from pyglotaran_extras.deprecation.deprecation_utils import warn_deprecated + +if TYPE_CHECKING: + + from _pytest.monkeypatch import MonkeyPatch + +OVERDUE_ERROR_MESSAGE = ( + "Support for 'pyglotaran_extras.deprecation.deprecation_utils.parse_version' " + "was supposed to be dropped in version: '0.6.0'.\n" + "Current version is: '1.0.0'" +) + +DEP_UTILS_QUALNAME = "pyglotaran_extras.deprecation.deprecation_utils" + +DEPRECATION_QUAL_NAME = f"{DEP_UTILS_QUALNAME}.parse_version(version_str)" +NEW_QUAL_NAME = f"{DEP_UTILS_QUALNAME}.check_overdue(qualnames)" + +DEPRECATION_WARN_MESSAGE = ( + "Usage of 'pyglotaran_extras.deprecation.deprecation_utils.parse_version(version_str)' " + "was deprecated, use " + "'pyglotaran_extras.deprecation.deprecation_utils.check_overdue(qualnames)' " + "instead.\nThis usage will be an error in version: '0.6.0'." +) + + +@pytest.fixture +def pyglotaran_extras_0_3_0(monkeypatch: MonkeyPatch): + """Mock pyglotaran_extras version to always be 0.3.0 for the test.""" + monkeypatch.setattr( + pyglotaran_extras.deprecation.deprecation_utils, # type:ignore[attr-defined] + "pyglotaran_extras_version", + lambda: "0.3.0", + ) + yield + + +@pytest.fixture +def pyglotaran_extras_1_0_0(monkeypatch: MonkeyPatch): + """Mock pyglotaran_extras version to always be 1.0.0 for the test.""" + monkeypatch.setattr( + pyglotaran_extras.deprecation.deprecation_utils, # type:ignore[attr-defined] + "pyglotaran_extras_version", + lambda: "1.0.0", + ) + yield + + +def test_pyglotaran_extras_version(): + """Versions are the same.""" + assert pyglotaran_extras_version() == pyglotaran_extras.__version__ + + +@pytest.mark.parametrize( + "version_str, expected", + ( + ("0.0.1", (0, 0, 1)), + ("0.0.1.post", (0, 0, 1)), + ("0.0.1-dev", (0, 0, 1)), + ("0.0.1-dev.post", (0, 0, 1)), + ), +) +def test_parse_version(version_str: str, expected: tuple[int, int, int]): + """Valid version strings.""" + assert parse_version(version_str) == expected + + +@pytest.mark.parametrize( + "version_str", + ("1", "0.1", "a.b.c"), +) +def test_parse_version_errors(version_str: str): + """Invalid version strings""" + with pytest.raises(ValueError, match=f"'{version_str}'"): + parse_version(version_str) + + +@pytest.mark.usefixtures("pyglotaran_extras_0_3_0") +def test_check_overdue_no_raise(monkeypatch: MonkeyPatch): + """Current version smaller then drop_version.""" + check_overdue( + deprecated_qual_name_usage=DEPRECATION_QUAL_NAME, + to_be_removed_in_version="0.6.0", + ) + + +@pytest.mark.usefixtures("pyglotaran_extras_1_0_0") +def test_check_overdue_raises(monkeypatch: MonkeyPatch): + """Current version is equal or bigger than drop_version.""" + with pytest.raises(OverDueDeprecation) as excinfo: + check_overdue( + deprecated_qual_name_usage=DEPRECATION_QUAL_NAME, + to_be_removed_in_version="0.6.0", + ) + + assert str(excinfo.value) == OVERDUE_ERROR_MESSAGE + + +@pytest.mark.usefixtures("pyglotaran_extras_0_3_0") +def test_warn_deprecated(): + """Warning gets shown when all is in order.""" + with pytest.warns(PyglotaranExtrasApiDeprecationWarning) as record: + warn_deprecated( + deprecated_qual_name_usage=DEPRECATION_QUAL_NAME, + new_qual_name_usage=NEW_QUAL_NAME, + to_be_removed_in_version="0.6.0", + ) + + assert len(record) == 1 + assert record[0].message.args[0] == DEPRECATION_WARN_MESSAGE + assert Path(record[0].filename) == Path(__file__) + + +@pytest.mark.usefixtures("pyglotaran_extras_1_0_0") +def test_warn_deprecated_overdue_deprecation(monkeypatch: MonkeyPatch): + """Current version is equal or bigger than drop_version.""" + + with pytest.raises(OverDueDeprecation) as excinfo: + warn_deprecated( + deprecated_qual_name_usage=DEPRECATION_QUAL_NAME, + new_qual_name_usage=NEW_QUAL_NAME, + to_be_removed_in_version="0.6.0", + ) + assert str(excinfo.value) == OVERDUE_ERROR_MESSAGE + + +@pytest.mark.filterwarnings("ignore:Usage") +@pytest.mark.xfail(strict=True, reason="Dev version aren't checked") +def test_warn_deprecated_no_overdue_deprecation_on_dev(monkeypatch: MonkeyPatch): + """Current version is equal or bigger than drop_version but it's a dev version.""" + monkeypatch.setattr( + pyglotaran_extras.deprecation.deprecation_utils, # type:ignore[attr-defined] + "glotaran_version", + lambda: "0.6.0-dev", + ) + + with pytest.raises(OverDueDeprecation): + warn_deprecated( + deprecated_qual_name_usage=DEPRECATION_QUAL_NAME, + new_qual_name_usage=NEW_QUAL_NAME, + to_be_removed_in_version="0.6.0", + ) diff --git a/tests/io/test_get_script_dir.ipynb b/tests/io/test_get_script_dir.ipynb index 181b50ce..477df594 100644 --- a/tests/io/test_get_script_dir.ipynb +++ b/tests/io/test_get_script_dir.ipynb @@ -9,7 +9,7 @@ "source": [ "from pathlib import Path\n", "\n", - "from pyglotaran_extras.io.boilerplate import get_script_dir" + "from pyglotaran_extras.io.setup_case_study import get_script_dir" ] }, { diff --git a/tests/io/test_boilerplate.py b/tests/io/test_setup_case_study.py similarity index 88% rename from tests/io/test_boilerplate.py rename to tests/io/test_setup_case_study.py index 973a0fcd..218e8c0e 100644 --- a/tests/io/test_boilerplate.py +++ b/tests/io/test_setup_case_study.py @@ -8,8 +8,8 @@ from tests.conftest import wrapped_get_script_dir -from pyglotaran_extras.io.boilerplate import get_script_dir -from pyglotaran_extras.io.boilerplate import setup_case_study +from pyglotaran_extras.io.setup_case_study import get_script_dir +from pyglotaran_extras.io.setup_case_study import setup_case_study if TYPE_CHECKING: from _pytest.monkeypatch import MonkeyPatch @@ -34,7 +34,7 @@ def test_get_script_dir_tmp_path(tmp_path: Path): tmp_file = tmp_path / "foo.py" content = dedent( """ - from pyglotaran_extras.io.boilerplate import get_script_dir + from pyglotaran_extras.io.setup_case_study import get_script_dir print(get_script_dir()) """ ) @@ -44,6 +44,7 @@ def test_get_script_dir_tmp_path(tmp_path: Path): ) result = printed_result.stdout.decode().rstrip("\n\r") + assert printed_result.returncode == 0 assert Path(result) == tmp_path.resolve()