diff --git a/.editorconfig b/.editorconfig index 815ec44d..7ea21bb9 100644 --- a/.editorconfig +++ b/.editorconfig @@ -10,7 +10,7 @@ insert_final_newline = true charset = utf-8 end_of_line = lf -[*.py] +[*.{py,ipynb}] indent_size = 4 [*.bat] diff --git a/.github/requirements_min.txt b/.github/requirements_min.txt index e622faaa..ac3f066d 100644 --- a/.github/requirements_min.txt +++ b/.github/requirements_min.txt @@ -2,7 +2,10 @@ cycler==0.10 numpy==1.22 +docstring-parser==0.16 matplotlib==3.3 +pydantic==2.0.0 pyglotaran==0.7.2 +ruamel-yaml==0.18.6 tabulate==0.8.9 xarray==2022.3 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 38abc694..71aef958 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -108,10 +108,10 @@ jobs: run: | python -m pip install -U pip wheel python -m pip install -r requirements_pinned.txt - python -m pip install -U -e ".[test]" + python -m pip install -U ".[test]" - name: Run tests - run: python -m pytest --nbval --cov=./ --cov-report term --cov-report xml --cov-config pyproject.toml tests + run: python -m pytest --nbval --cov=pyglotaran_extras --cov-report term --cov-report xml --cov-config pyproject.toml tests - name: Codecov Upload continue-on-error: true @@ -140,14 +140,14 @@ jobs: run: | python -m pip install -U pip wheel python -m pip install -r requirements_pinned.txt - python -m pip install -U -e ".[test]" + python -m pip install -U ".[test]" python -m pip install git+https://github.com/glotaran/pyglotaran - name: Show installed dependencies run: pip freeze - name: Run tests - run: python -m pytest --nbval --cov=./ --cov-report term --cov-report xml --cov-config pyproject.toml tests + run: python -m pytest --nbval --cov=pyglotaran_extras --cov-report term --cov-report xml --cov-config pyproject.toml tests - name: Codecov Upload continue-on-error: true diff --git a/.gitignore b/.gitignore index 74bcc4d8..81a1ad2e 100644 --- a/.gitignore +++ b/.gitignore @@ -46,6 +46,8 @@ profile.out # Distribution / packaging .Python env/ +venv/ +.venv/ build/ develop-eggs/ dist/ @@ -93,10 +95,14 @@ coverage.xml # Sphinx documentation docs/_build/ docs/api/ +docs/_static # documents generated by Sphinx.ext.autosummary docs/source/user_documentation/api/* # doc figures docs/source/images/plot +# Files generated by the config docs +pygta_config.schema.json +docs/config/project/subproject/pygta_config.yml # PyBuilder target/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e5e625e6..772b24f0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: args: [--in-place, --config, ./pyproject.toml] - repo: https://github.com/tox-dev/pyproject-fmt - rev: "2.1.3" + rev: "2.2.1" hooks: - id: pyproject-fmt @@ -46,7 +46,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.4.7 + rev: v0.6.2 hooks: - id: ruff name: "ruff sort imports notebooks" @@ -72,11 +72,11 @@ repos: # Linters - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.10.0 + rev: v1.11.1 hooks: - id: mypy exclude: ^docs - additional_dependencies: [types-tabulate] + additional_dependencies: [types-tabulate, pydantic] - repo: https://github.com/econchick/interrogate rev: 1.7.0 @@ -88,7 +88,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.4.7 + rev: v0.6.2 hooks: - id: ruff name: "ruff sort imports" @@ -103,20 +103,20 @@ repos: name: "ruff lint" - repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 + rev: 7.1.1 hooks: - id: flake8 alias: flake8-docs args: - "--select=DOC" - - "--extend-ignore=DOC502" + - "--extend-ignore=DOC502,DOC601,DOC603" - "--color=always" - "--require-return-section-when-returning-nothing=False" - "--allow-init-docstring=True" - "--skip-checking-short-docstrings=False" name: "flake8 lint docstrings" exclude: "^(docs/|tests?/)" - additional_dependencies: [pydoclint==0.3.8] + additional_dependencies: [pydoclint==0.5.6] - repo: https://github.com/codespell-project/codespell rev: v2.3.0 diff --git a/.ruff-notebooks.toml b/.ruff-notebooks.toml index 79fc512b..d2f3db5c 100644 --- a/.ruff-notebooks.toml +++ b/.ruff-notebooks.toml @@ -1,7 +1,18 @@ #:schema https://json.schemastore.org/ruff.json extend = ".ruff.toml" -extend-ignore = ["D", "E402", "F404"] -[isort] +extend-exclude = [ + "docs/conf.py", +] + +[lint] +extend-ignore = [ + "D", + "E402", + "F404", + "I002", # from __future__ import annotations +] + +[lint.isort] required-imports = [] force-single-line = false diff --git a/.ruff.toml b/.ruff.toml index aa23960d..2da1d052 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -1,7 +1,7 @@ # Exclude a variety of commonly ignored directories. extend-exclude = [ - "venv", "docs/conf.py", + "*.ipynb", ] # Same as Black. line-length = 99 @@ -9,32 +9,30 @@ line-length = 99 # Assume Python 3.10. target-version = "py310" -# Enable using ruff with notebooks -extend-include = ["*.ipynb"] [lint] select = [ - "E", # pycodestyle - "W", # pycodestyle - "C", # mccabe - "F", # pyflakes - "UP", # pyupgrade - "D", # pydocstyle - "N", # pep8-naming + "E", # pycodestyle + "W", # pycodestyle + "C", # mccabe + "F", # pyflakes + "UP", # pyupgrade + "D", # pydocstyle + "N", # pep8-naming "YTT", # flake8-2020 "BLE", # flake8-blind-except # "FBT", # flake8-boolean-trap - "B", # flake8-bugbear - "C4", # flake8-comprehensions + "B", # flake8-bugbear + "C4", # flake8-comprehensions "T10", # flake8-debugger - "FA", # flake8-future-annotations - "EM", # flake8-errmsg - "I", # isort (activates import sorting for formatter) + "FA", # flake8-future-annotations + "EM", # flake8-errmsg + "I", # isort (activates import sorting for formatter) "ISC", # flake8-implicit-str-concat "INP", # flake8-no-pep420 "PIE", # flake8-pie "T20", # flake8-print - "PT", # flake8-pytest-style + "PT", # flake8-pytest-style "RSE", # flake8-raise "RET", # flake8-return "SIM", # flake8-simplify @@ -42,7 +40,7 @@ select = [ "ARG", # flake8-unused-arguments "PTH", # flake8-use-pathlib "ERA", # eradicate - "PD", # pandas-vet + "PD", # pandas-vet "PGH", # pygrep-hooks "NPY", # NumPy-specific "RUF", # Ruff-specific @@ -60,13 +58,17 @@ ignore = [ # Covered by formatter "ISC001", ] +external = ["DOC"] # Allow unused variables when underscore-prefixed. dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" [lint.per-file-ignores] "tests/*" = ["ARG001"] - +"tests/data/*" = ["INP", "D"] +"tests/data/config/run_load_config_on_import.py" = [ + "I002", # from __future__ import annotations +] [lint.isort] required-imports = ["from __future__ import annotations"] known-first-party = ["pyglotaran_extras"] diff --git a/docs/conf.py b/docs/conf.py index e2ec91ea..fa98fb96 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -17,9 +17,13 @@ # relative to the documentation root, use os.path.abspath to make it # absolute, like shown here. # - +from pathlib import Path import pyglotaran_extras +HERE = Path(__file__).parent + +pyglotaran_extras.create_config_schema(HERE/"_static") + # -- General configuration --------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. @@ -33,23 +37,31 @@ "sphinx.ext.autosummary", "sphinx.ext.viewcode", "sphinx.ext.napoleon", - "myst_parser", + "myst_nb", + 'sphinxcontrib.mermaid', + "sphinx_copybutton", "sphinx_rtd_theme", ] +myst_fence_as_directive = ["mermaid"] + autoclass_content = "both" autosummary_generate = True add_module_names = False autodoc_member_order = "bysource" +autodoc_pydantic_model_show_config_summary=False +autodoc_pydantic_model_show_validator_summary=False + + # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # -source_suffix = [".rst", ".md"] +source_suffix = {'.rst': 'restructuredtext', '.md': 'restructuredtext'} # source_suffix = '.rst' linkcheck_ignore = [ r"https://github\.com/glotaran/pyglotaran-extras/actions", @@ -111,7 +123,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -# html_static_path = ["_static"] +html_static_path = ["_static"] # -- Options for HTMLHelp output --------------------------------------- diff --git a/docs/config/project/fs_config.yml b/docs/config/project/fs_config.yml new file mode 100644 index 00000000..eb5a8ef2 --- /dev/null +++ b/docs/config/project/fs_config.yml @@ -0,0 +1,4 @@ +plotting: + general: + axis_label_override: + time: "Time (fs)" diff --git a/docs/config/project/pygta_config.yml b/docs/config/project/pygta_config.yml new file mode 100644 index 00000000..88c8a1fa --- /dev/null +++ b/docs/config/project/pygta_config.yml @@ -0,0 +1,18 @@ +plotting: + general: + default_args_override: + linlog: true + use_svd_number: true + linthresh: 2 + axis_label_override: + time: "Time (ps)" + spectral: "Wavelength (nm)" + data_left_singular_vectors: "" + data_singular_values: "Singular Value (a.u.)" + data_right_singular_vectors: "" + plot_svd: + default_args_override: + use_svd_number: false + plot_overview: + default_args_override: + show_data: True diff --git a/docs/config/project/subproject/config_docs.ipynb b/docs/config/project/subproject/config_docs.ipynb new file mode 100644 index 00000000..4f3f1cfb --- /dev/null +++ b/docs/config/project/subproject/config_docs.ipynb @@ -0,0 +1,549 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Configuration\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Usage\n", + "\n", + "Let's assume you started a new project and created a [`jupyter notebook`](https://jupyter.org/)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%ls" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize a new project\n", + "\n", + "The recommended way is to import the `CONFIG` and initialize a new project. However this isn't required just recommended.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyglotaran_extras import CONFIG\n", + "\n", + "CONFIG.init_project()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will look up all config files in your home folder, the notebook folders parent folder,\n", + "and the notebook folder, combine them and create a new config and schema file in the notebook\n", + "folder for you, as well as rediscovering and reloading the config (see [file-lookup](#file-lookup)).\n", + "\n", + "```{note}\n", + "If a config file already exists, the file creation will be skipped in order to not overwrite an\n", + "exported custom schema with your own plot functions.\n", + "```\n", + "\n", + "```{admonition} Tip\n", + "If you don't want the config to be shown in the cell output, just add a `;` after `CONFIG.init_project()`.\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%ls" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Discovering and loading config files\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you want to only work with one config file you can simply load it.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CONFIG.load(\"../fs_config.yml\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you don't like the way config files are looked up you can manually rediscover them and reload the config.\n", + "\n", + "```{note}\n", + "Note that the reload is only used for demonstration purposes, since the config is autoreloaded before being used (see [auto-reload](#auto-reload))\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CONFIG.rediscover(include_home_dir=False, lookup_depth=3)\n", + "CONFIG.reload()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### How the config affects plotting\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To demonstrate the difference between not using the config and using the config we create a copy of our `project_config` as well as an `empty_config` (same as not having a config at all).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyglotaran_extras.config.config import Config\n", + "\n", + "project_config = CONFIG.model_copy(deep=True)\n", + "empty_config = Config()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Default plotting behavior\n", + "\n", + "By default plots don't do renaming to make it easier to find the underlying data in the dataset.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from glotaran.testing.simulated_data.parallel_spectral_decay import DATASET\n", + "\n", + "from pyglotaran_extras import plot_data_overview\n", + "\n", + "CONFIG._reset(empty_config)\n", + "\n", + "plot_data_overview(DATASET);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Manually adjusting the Plot\n", + "\n", + "So in order to make you plots ready for a publication you have to set all the labels and\n", + "add plot function arguments each time you call it, and keeping things in sync for all plots you generate.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plot_data_overview(DATASET, linlog=True, linthresh=2, use_svd_number=True)\n", + "axes[0].set_xlabel(\"Time (ps)\")\n", + "axes[0].set_ylabel(\"Wavelength (nm)\")\n", + "axes[1].set_xlabel(\"Time (ps)\")\n", + "axes[1].set_ylabel(\"\")\n", + "axes[2].set_ylabel(\"Singular Value (a.u.)\")\n", + "axes[3].set_xlabel(\"Wavelength (nm)\")\n", + "axes[3].set_ylabel(\"\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Using the plot config\n", + "\n", + "The same as with manually changing your plots and function arguments can be achieved with plot config,\n", + "but it is way less code, keeps all plots in sync for you and spares you from copy pasting the same things all\n", + "over the place.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CONFIG._reset(project_config)\n", + "\n", + "plot_data_overview(DATASET);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Temporarily changing the config\n", + "\n", + "Let's assume that one dataset uses wavenumbers instead of wavelength as spectral axis.\n", + "\n", + "You can simply define a `PerFunctionPlotConfig` and call your plot function inside of a `plot_config_context`.\n", + "This way you can even override function specific configuration defined in your config file.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyglotaran_extras import PerFunctionPlotConfig\n", + "from pyglotaran_extras import plot_config_context\n", + "\n", + "my_plot_config = PerFunctionPlotConfig(\n", + " axis_label_override={\"spectral\": \"Wavenumber (cm$^{-1}$)\"}\n", + ")\n", + "\n", + "with plot_config_context(my_plot_config):\n", + " plot_data_overview(DATASET)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using the config for you own function\n", + "\n", + "The plot config isn't just for our builtin functions but you can also use it with your own custom \n", + "functions.\n", + "\n", + "```{note}\n", + "For axes label changing to work with you function the function needs to either take them as argument \n", + "or return them.\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyglotaran_extras import use_plot_config\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "@use_plot_config()\n", + "def my_plot(swap_axis=False):\n", + " fig, ax = plt.subplots()\n", + " ax.set_xlabel(\"x\")\n", + " ax.set_ylabel(\"y\")\n", + " x = np.linspace(-10,10)\n", + " y = x**2\n", + " if swap_axis is True:\n", + " x,y = y,x\n", + " ax.plot(x,y,)\n", + " return fig, ax\n", + "\n", + "\n", + "my_plot();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For quick prototyping of our config we will just use `PerFunctionPlotConfig` and `plot_config_context`\n", + "from the previous section. \n", + "\n", + "```{note}\n", + "If you aren't writing documentation you can just export the config to update the json schema and \n", + "change the file directly including editor support 😅.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_plot_config = PerFunctionPlotConfig(\n", + " axis_label_override={\"x\":\"x-axis\",\"y\":\"y-axis\"},\n", + " default_args_override={\"swap_axis\":True}\n", + ")\n", + "\n", + "with plot_config_context(my_plot_config):\n", + " my_plot();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we are happy with the config we can just look at the corresponding yaml and \n", + "copy paste it into a new `my_plot` section inside of the `plotting` section in the config." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_plot_config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally we can export the config that is aware of our new function `my_plot`, which will:\n", + "- Update the existing config (nothing to do in this case)\n", + "- Update the schema file to know about `my_plot`\n", + "\n", + "So the next time we change something in our config it will be able to autocomplete and lint our the content." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CONFIG.export()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## FAQ\n", + "\n", + "### Do I have to use the config?\n", + "\n", + "No. Using the config is fully optional, however we recommend using it since it reduces the amount of\n", + "code you need to write and lets anybody reading your analysis focus on the science rather than the\n", + "python code used to make your plots.\n", + "\n", + "### What can the configuration be used for?\n", + "\n", + "The main goal of the config is to configure plot functions and reduce tedious code duplication like:\n", + "\n", + "- Renaming labels of axes\n", + "- Overriding default values to plot function calls\n", + "\n", + "We try to have sensible default values for our plot functions, but there is no `one fits all` solution.\n", + "\n", + "Especially since arguments like `linthresh` (determines the range in which a `linlog` plot is linear)\n", + "are highly dependent on your data.\n", + "\n", + "Thus we give you the power to customize the default values to your projects needs, without having\n", + "repeating them over and over each time you call a plot function.\n", + "\n", + "### Can I still change plot labels myself?\n", + "\n", + "Yes, the config gets applied when a config enabled plot function is called you can still\n", + "work with the return figure and axes as you are used to be.\n", + "\n", + "### Does using a config mean arguments I pass to a function get ignored?\n", + "\n", + "No, arguments from the config are only used you don't pass an argument.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(file-lookup)=\n", + "\n", + "### How are config files looked up?\n", + "\n", + "When you import anything from `pyglotaran_extras` the location of your project is determined.\n", + "This location then is used to look for `pygta_config.yaml` and `pygta_config.yml` in the following folders:\n", + "\n", + "- Your user home directory\n", + "- The projects parent folder\n", + "- The project folder\n", + "\n", + "If you don't want to include your home folder or a different lookup depth relative to your project\n", + "folder you can use `CONFIG.rediscover`.\n", + "If you only want to load the config from a single file you can use `CONFIG.load`.\n", + "\n", + "(auto-reload)=\n", + "\n", + "### Do I need to reload the config after changing a file?\n", + "\n", + "No, the config keeps track of when each config file was last modified and automatically reloads if needed.\n", + "\n", + "(value-determination)=\n", + "\n", + "### How is determined what config values to use?\n", + "\n", + "The config follows the locality and specificity principles.\n", + "\n", + "#### Locality\n", + "\n", + "Locality here means that the closer the configuration is to the plot function call the higher its importance.\n", + "\n", + "Lets consider the example of the default behavior where configs are looked up in the home directory,\n", + "projects parent folder and project folder.\n", + "When the global `CONFIG` instance is loaded it merges the configs in the following order:\n", + "\n", + "- Your user home directory\n", + "- The projects parent folder\n", + "- The project folder\n", + "\n", + "Where each merge overwrites duplicate values from the config it gets merged into.\n", + "\n", + "#### Specificity\n", + "\n", + "For ease of use and reduced duplications, the plot config has a `general` section\n", + "that applies to a plot function with use those arguments or plot labels.\n", + "\n", + "Lets assume that your experimental data use time in picoseconds (ps) and wavelengths in nanometers (nm).\n", + "Instead of a defining the label override for each function you can simply it to the general section as\n", + "see above and if a function doesn't have it defined itself it also gets applied for this function.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CONFIG.plotting.get_function_config(\"plot_svd\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To demonstrate the effects on the config we will reuse `wavenumber_config` for the usage example.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with plot_config_context(my_plot_config):\n", + " plot_svd_config_wavenumber = CONFIG.plotting.get_function_config(\"plot_svd\")\n", + "plot_svd_config_wavenumber" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This change is only valid inside of the `plot_config_context` and reset afterwards\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Which arguments and label are used are defined by the following hierarchy.\n", + "\n", + "- Plot function arguments\n", + "- `plot_config_context`\n", + "- Global `CONFIG.plotting` instance `function config`\n", + "- Global `CONFIG.plotting` instance `general`\n", + "\n", + "````{note}\n", + "For compound functions like `plot_overview` which consist of multiple plot config enabled functions\n", + "the `default_args_override` for `plot_overview` will be passed down to the other functions and\n", + "override their usage of own `default_args_override` config (if arguments are passed they aren't\n", + "default arguments anymore 😅).\n", + "Where as `axis_label_override` for the functions config is first applied to the intermediate plots\n", + "and `axis_label_override` from `plot_overview` is only applied after that on final plot.\n", + "\n", + "```mermaid\n", + "graph TD\n", + " A[plot_overview] --> |\"default_args_override (plot_overview)\"| B[plot_svd]\n", + " B --> |\"axis_label_override (plot_svd)\"| C[intermediate plot]\n", + " C --> |\"axis_label_override (plot_overview)\"| D[final plot]\n", + "```\n", + "````\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CONFIG.plotting.get_function_config(\"plot_svd\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### What is the `pygta_config.schema.json` file for?\n", + "\n", + "TLDR; It enables autocomplete and error detection in your editor.\n", + "\n", + "[JSON-schema](https://json-schema.org/) is a format that is used to describe data structures\n", + "including their types in a language agnostic way.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pyglotaran310", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/environment.yml b/docs/environment.yml new file mode 100644 index 00000000..cec84d47 --- /dev/null +++ b/docs/environment.yml @@ -0,0 +1,11 @@ +name: RTD-env +channels: + - conda-forge +dependencies: + # Python interpreter + - python=3.10 + - pandoc>=3.2.1 + - pip + - pip: + - -r ../requirements_pinned.txt + - ..[docs] diff --git a/docs/index.md b/docs/index.md index 9ee9fd8b..0920e11c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -7,6 +7,7 @@ readme installation usage +config/project/subproject/config_docs api_docs contributing changelog diff --git a/pyglotaran_extras/__init__.py b/pyglotaran_extras/__init__.py index f90165d8..2ea1c4c6 100644 --- a/pyglotaran_extras/__init__.py +++ b/pyglotaran_extras/__init__.py @@ -2,15 +2,35 @@ from __future__ import annotations +from pyglotaran_extras.config.config import _find_script_dir_at_import +from pyglotaran_extras.config.config import create_config_schema +from pyglotaran_extras.config.config import load_config +from pyglotaran_extras.config.plot_config import PerFunctionPlotConfig +from pyglotaran_extras.config.plot_config import plot_config_context +from pyglotaran_extras.config.plot_config import use_plot_config from pyglotaran_extras.io.load_data import load_data from pyglotaran_extras.io.setup_case_study import setup_case_study from pyglotaran_extras.plotting.plot_coherent_artifact import plot_coherent_artifact +from pyglotaran_extras.plotting.plot_concentrations import plot_concentrations from pyglotaran_extras.plotting.plot_data import plot_data_overview from pyglotaran_extras.plotting.plot_doas import plot_doas from pyglotaran_extras.plotting.plot_guidance import plot_guidance from pyglotaran_extras.plotting.plot_irf_dispersion_center import plot_irf_dispersion_center from pyglotaran_extras.plotting.plot_overview import plot_overview from pyglotaran_extras.plotting.plot_overview import plot_simple_overview +from pyglotaran_extras.plotting.plot_residual import plot_residual +from pyglotaran_extras.plotting.plot_spectra import plot_das +from pyglotaran_extras.plotting.plot_spectra import plot_norm_das +from pyglotaran_extras.plotting.plot_spectra import plot_norm_sas +from pyglotaran_extras.plotting.plot_spectra import plot_sas +from pyglotaran_extras.plotting.plot_spectra import plot_spectra +from pyglotaran_extras.plotting.plot_svd import plot_lsv_data +from pyglotaran_extras.plotting.plot_svd import plot_lsv_residual +from pyglotaran_extras.plotting.plot_svd import plot_rsv_data +from pyglotaran_extras.plotting.plot_svd import plot_rsv_residual +from pyglotaran_extras.plotting.plot_svd import plot_sv_data +from pyglotaran_extras.plotting.plot_svd import plot_sv_residual +from pyglotaran_extras.plotting.plot_svd import plot_svd from pyglotaran_extras.plotting.plot_traces import plot_fitted_traces from pyglotaran_extras.plotting.plot_traces import select_plot_wavelengths from pyglotaran_extras.plotting.utils import add_subplot_labels @@ -19,15 +39,40 @@ "load_data", "setup_case_study", "plot_coherent_artifact", + "plot_concentrations", "plot_data_overview", "plot_doas", "plot_guidance", "plot_irf_dispersion_center", "plot_overview", "plot_simple_overview", + "plot_residual", + "plot_das", + "plot_norm_das", + "plot_norm_sas", + "plot_sas", + "plot_spectra", + "plot_lsv_data", + "plot_lsv_residual", + "plot_rsv_data", + "plot_rsv_residual", + "plot_sv_data", + "plot_sv_residual", + "plot_svd", "plot_fitted_traces", "select_plot_wavelengths", "add_subplot_labels", + # Config + "PerFunctionPlotConfig", + "plot_config_context", + "use_plot_config", + "create_config_schema", + "CONFIG", ] __version__ = "0.7.2" + +SCRIPT_DIR = _find_script_dir_at_import(__file__) +"""User script dir determined during import.""" +CONFIG = load_config(SCRIPT_DIR) +"""Global config instance.""" diff --git a/pyglotaran_extras/config/__init__.py b/pyglotaran_extras/config/__init__.py new file mode 100644 index 00000000..56096f22 --- /dev/null +++ b/pyglotaran_extras/config/__init__.py @@ -0,0 +1 @@ +"""Configuration package.""" diff --git a/pyglotaran_extras/config/config.py b/pyglotaran_extras/config/config.py new file mode 100644 index 00000000..698edd43 --- /dev/null +++ b/pyglotaran_extras/config/config.py @@ -0,0 +1,470 @@ +"""Module containing configuration.""" + +from __future__ import annotations + +import importlib +import json +import sys +from pathlib import Path +from typing import TYPE_CHECKING +from typing import Any + +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import PrivateAttr +from pydantic import PydanticUserError +from pydantic import create_model +from pydantic.fields import FieldInfo +from ruamel.yaml import YAML + +from pyglotaran_extras.config.plot_config import PlotConfig +from pyglotaran_extras.config.plot_config import PlotLabelOverrideMap +from pyglotaran_extras.config.plot_config import __PlotFunctionRegistry +from pyglotaran_extras.config.utils import add_yaml_repr +from pyglotaran_extras.io.setup_case_study import get_script_dir + +if TYPE_CHECKING: + from collections.abc import Generator + from collections.abc import Iterable + +# Only imported for builtin schema generation +from collections.abc import Sequence # noqa: F401 +from typing import Literal # noqa: F401 + +CONFIG_FILE_STEM = "pygta_config" + +EXPORT_TEMPLATE = """\ +# yaml-language-server: $schema={schema_path} + +{config_yaml}\ +""" + + +class UsePlotConfigError(Exception): + """Error thrown when ``use_plot_config`` has none json serializable kwargs.""" + + def __init__(self, func_name: str, error: PydanticUserError) -> None: # noqa: DOC + """Use ``func_name`` and original ``error`` to create error message.""" + msg = ( + f"The function ``{func_name}`` decorated with ``use_plot_config`` has an keyword " + "argument with a type annotation can not be represents in the config.\n" + "Please use the name of this keyword argument in the ``exclude_from_config`` " + "keyword argument to ``use_plot_config``.\n" + f"Original error:\n{error}" + ) + super().__init__(msg) + + +@add_yaml_repr +class Config(BaseModel): + """Main configuration class.""" + + model_config = ConfigDict(extra="forbid") + + plotting: PlotConfig = PlotConfig() + _source_files: list[Path] = PrivateAttr(default_factory=list) + _source_hash: int = PrivateAttr(default=hash(())) + + def merge(self, other: Config) -> Config: + """Merge two ``Config``'s where ``other`` overrides values and return a new instance. + + Parameters + ---------- + other : Config + Other ``Config`` to merge in. + + Returns + ------- + Config + """ + merged = self.model_copy(deep=True) + merged.plotting = merged.plotting.merge(other.plotting) + for source_file in other._source_files: + if source_file in merged._source_files: + merged._source_files.remove(source_file) + merged._source_files.append(source_file) + merged._source_hash = merged._calculate_source_hash() + return merged + + def _reset(self, other: Config | None = None) -> Config: + """Reset self to ``other`` config or default initialization. + + Parameters + ---------- + other : Config | None + Other ``Config`` to to reset to. + + Returns + ------- + Config + """ + if other is None: + other = Config() + else: + self._source_files = other._source_files + self.plotting = other.plotting + return self + + def _calculate_source_hash(self) -> int: # noqa: DOC + """Calculate hash of source files based on their modification time.""" + return hash(tuple(source_file.stat().st_mtime for source_file in self._source_files)) + + def reload(self) -> Config: + """Reset and reload config from files. + + Returns + ------- + Config + """ + if self._source_hash == self._calculate_source_hash(): + return self + context_config = getattr(self.plotting, "__context_config", None) + merged = self._reset() + for config in load_config_files(self._source_files): + merged = merged.merge(config) + self.plotting = merged.plotting + if context_config is not None: + setattr(self.plotting, "__context_config", context_config) + self._source_hash = merged._source_hash + return self + + def load(self, config_file_path: Path | str) -> Config: + """Disregard current config and config file paths, and reload from ``config_file_path``. + + Parameters + ---------- + config_file_path : Path | str + Path to the config file to load. + + Returns + ------- + Config + """ + self._source_files = [Path(config_file_path)] + return self.reload() + + def export(self, export_folder: Path | str | None = None, *, update: bool = True) -> Path: + """Export current config and schema to ``export_folder``. + + Parameters + ---------- + export_folder : Path | str | None + Folder to export config and scheme to. Defaults to None, which means that the script + folder is used + update : bool + Whether to update or overwrite and existing config file. Defaults to True + + Returns + ------- + Path + Path to exported config file. + """ + if export_folder is None: + from pyglotaran_extras import SCRIPT_DIR + + export_folder = SCRIPT_DIR + else: + export_folder = Path(export_folder) + export_folder.mkdir(parents=True, exist_ok=True) + schema_path = create_config_schema(export_folder) + export_path = export_folder / f"{CONFIG_FILE_STEM}.yml" + if export_path.is_file() is True and update is True: + merged = Config().load(export_path).merge(self) + config = merged + else: + config = self + export_path.write_text( + EXPORT_TEMPLATE.format(schema_path=schema_path.name, config_yaml=config), + encoding="utf8", + ) + return export_path + + def rediscover(self, *, include_home_dir: bool = True, lookup_depth: int = 2) -> list[Path]: + """Rediscover config paths based on the ``SCRIPT_DIR`` discovered on import. + + Parameters + ---------- + include_home_dir : bool + Where or not to include the users home folder in the config lookup. Defaults to True + lookup_depth : int + Depth at which to look for configs in parent folders of ``script_dir``. + If set to ``1`` only ``script_dir`` will be considered as config dir. + Defaults to ``2``. + + Returns + ------- + list[Path] + Paths of the discovered config files. + """ + from pyglotaran_extras import SCRIPT_DIR + + self._source_files = list( + discover_config_files( + SCRIPT_DIR, include_home_dir=include_home_dir, lookup_depth=lookup_depth + ) + ) + return self._source_files + + def init_project(self) -> Config: + """Initialize configuration for the current project. + + This will use the configs discovered and resolved config during import to create a new + config and schema for your current project inside of your working directory (script dir), + if it didn't exist before. + + Returns + ------- + Config + """ + from pyglotaran_extras import SCRIPT_DIR + + if any(find_config_in_dir(SCRIPT_DIR)) is False: + self.export() + self.rediscover() + self.reload() + return self + + +def find_config_in_dir(dir_path: Path) -> Generator[Path, None, None]: + """Find the config file inside of dir ``dir_path``. + + Parameters + ---------- + dir_path : Path + Directory path to look for a config file. + + Yields + ------ + Path + """ + for extension in (".yaml", ".yml"): + config_file = (dir_path / CONFIG_FILE_STEM).with_suffix(extension) + if config_file.is_file(): + yield config_file + + +def discover_config_files( + script_dir: Path, *, include_home_dir: bool = True, lookup_depth: int = 2 +) -> Generator[Path, None, None]: + """Find config files in the users home folder and the current working dir and parents. + + Parameters + ---------- + script_dir : Path + Path to the current scripts/notebooks parent folder. + include_home_dir : bool + Where or not to include the users home folder in the config lookup. Defaults to True + lookup_depth : int + Depth at which to look for configs in parent folders of ``script_dir``. + If set to ``1`` only ``script_dir`` will be considered as config dir. + Defaults to ``2``. + + Yields + ------ + Path + """ + if include_home_dir is True: + yield from find_config_in_dir(Path.home()) + parent_dirs = tuple(reversed((script_dir / "dummy").parents)) + if lookup_depth > 0 and lookup_depth <= len(parent_dirs): + parent_dirs = parent_dirs[-lookup_depth:] + for parent in parent_dirs: + yield from find_config_in_dir(parent) + + +def load_config_files(config_paths: Iterable[Path]) -> Generator[Config, None, None]: + """Load config files into new config instances. + + Parameters + ---------- + config_paths : Iterable[Path] + Path to the config file. + + Yields + ------ + Config + """ + yaml = YAML() + for config_path in config_paths: + try: + config_dict = yaml.load(config_path) + config = Config.model_validate(config_dict) if config_dict is not None else Config() + config._source_files.append(config_path) + yield config + # We use a very broad range of exception to ensure the config loading at import never + # breaks importing + except Exception as error: # noqa: BLE001 + print( # noqa: T201 + "Error loading the config:\n", + f"Source path: {config_path.as_posix()}\n", + f"Error: {error}", + file=sys.stderr, + sep="", + ) + + +def merge_configs(configs: Iterable[Config]) -> Config: + """Merge ``Config``'s from left to right, where the right ``Config`` overrides the left. + + Parameters + ---------- + configs : Iterable[Config] + Config instances to merge together. + + Returns + ------- + Config + """ + full_config = Config() + for config in configs: + full_config = full_config.merge(config) + return full_config + + +def load_config( + script_dir: Path, *, include_home_dir: bool = True, lookup_depth: int = 2 +) -> Config: + """Discover and load config files. + + Parameters + ---------- + script_dir : Path + Path to the current scripts/notebooks parent folder. + include_home_dir : bool + Where or not to include the users home folder in the config lookup. Defaults to True + lookup_depth : int + Depth at which to look for configs in parent folders of ``script_dir``. + If set to ``1`` only ``script_dir`` will be considered as config dir. + Defaults to ``2``. + + Returns + ------- + Config + + See Also + -------- + discover_config_files + """ + config_paths = discover_config_files( + script_dir, include_home_dir=include_home_dir, lookup_depth=lookup_depth + ) + configs = load_config_files(config_paths) + return merge_configs(configs) + + +def _find_script_dir_at_import(package_root_file: str) -> Path: + """Find the script dir when importing ``pyglotaran_extras``. + + The assumption is that the first file not inside of ``pyglotaran_extras`` or importlib + is the script in question. + The max ``nesting_offset`` of 20 was chosen semi arbitrarily (typically ``nesting + offset`` + is around 9-13 depending on the import) to ensure that there won't be an infinite loop. + + Parameters + ---------- + package_root_file : str + The dunder file attribute (``__file__``) in the package root file. + + Returns + ------- + Path + """ + nesting_offset = 0 + importlib_path = Path(importlib.__file__).parent + package_root = Path(package_root_file).parent + script_dir = get_script_dir(nesting=2) + while ( + importlib_path in (script_dir / "dummy").parents + or package_root in (script_dir / "dummy").parents + ) and nesting_offset < 20: + nesting_offset += 1 + script_dir = get_script_dir(nesting=2 + nesting_offset) + return script_dir + + +def create_config_schema( + output_folder: Path | str | None = None, + file_name: Path | str = f"{CONFIG_FILE_STEM}.schema.json", +) -> Path: + """Create json schema file to be used for autocompletion and linting of the config. + + Parameters + ---------- + output_folder : Path | str | None + Folder to write schema file to. Defaults to None, which means that the script + folder is used + file_name : Path | str + Name of the scheme file. Defaults to "pygta_config.schema.json" + + Returns + ------- + Path + Path to the file the schema got saved to. + + Raises + ------ + UsePlotConfigError + If any function decorated with ``use_plot_config`` has a keyword argument with a default + value and a type annotation that can not be serialized into a json schema. + """ + json_schema = Config.model_json_schema() + general_kwargs: dict[str, Any] = {} + + for function_name, default_kwargs in __PlotFunctionRegistry.items(): + try: + name_prefix = "".join([parts.capitalize() for parts in function_name.split("_")]) + fields: Any = { + kwarg_name: ( + kwarg_value["annotation"], + FieldInfo( + default=kwarg_value["default"], description=kwarg_value["docstring"] + ), + ) + for kwarg_name, kwarg_value in default_kwargs.items() + } + kwargs_model_name = f"{name_prefix}Kwargs" + func_kwargs = create_model( + kwargs_model_name, + __config__=ConfigDict(extra="forbid"), + __doc__=( + f"Default arguments to use for ``{function_name}``, " + "if not specified in function call." + ), + **fields, + ) + config_model_name = f"{name_prefix}Config" + func_config = create_model( + config_model_name, + __config__=ConfigDict(extra="forbid"), + __doc__=( + f"Plot function configuration specific to ``{function_name}`` " + "(overrides values in general)." + ), + default_args_override=(func_kwargs, {}), + axis_label_override=(PlotLabelOverrideMap, PlotLabelOverrideMap()), + ) + func_json_schema = func_config.model_json_schema() + general_kwargs |= func_json_schema["$defs"][kwargs_model_name]["properties"] + json_schema["$defs"] |= func_json_schema.pop("$defs") + json_schema["$defs"][config_model_name] = func_json_schema + json_schema["$defs"]["PlotConfig"]["properties"][function_name] = { + "allOf": [{"$ref": f"#/$defs/{config_model_name}"}] + } + except PydanticUserError as error: + raise UsePlotConfigError(function_name, error) # noqa: B904 + json_schema["$defs"]["PerFunctionPlotConfig"]["properties"]["default_args_override"][ + "properties" + ] = general_kwargs + json_schema["$defs"]["PerFunctionPlotConfig"]["properties"]["default_args_override"][ + "additionalProperties" + ] = False + if output_folder is None: + from pyglotaran_extras import SCRIPT_DIR + + output_folder = SCRIPT_DIR + else: + output_folder = Path(output_folder) + output_folder.mkdir(parents=True, exist_ok=True) + output_file = output_folder / file_name + output_file.write_text(json.dumps(json_schema, ensure_ascii=False), encoding="utf8") + return output_file diff --git a/pyglotaran_extras/config/plot_config.py b/pyglotaran_extras/config/plot_config.py new file mode 100644 index 00000000..21aebb39 --- /dev/null +++ b/pyglotaran_extras/config/plot_config.py @@ -0,0 +1,559 @@ +"""Module containing plot configuration.""" + +from __future__ import annotations + +from collections.abc import Generator +from collections.abc import Iterable +from collections.abc import Iterator +from collections.abc import Mapping +from collections.abc import MutableMapping +from contextlib import contextmanager +from functools import wraps +from inspect import Parameter +from inspect import getcallargs +from inspect import signature +from typing import TYPE_CHECKING +from typing import Any +from typing import Literal +from typing import TypeAlias +from typing import TypedDict +from typing import cast + +import numpy as np +from docstring_parser import parse as parse_docstring +from matplotlib.axes import Axes +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field +from pydantic import RootModel +from pydantic import ValidationError +from pydantic import field_validator +from pydantic import model_serializer +from pydantic import model_validator +from pydantic_core import ErrorDetails +from pydantic_core import PydanticUndefined + +from pyglotaran_extras.config.utils import add_yaml_repr + +if TYPE_CHECKING: + from collections.abc import Callable + + from pyglotaran_extras.config.config import Config + from pyglotaran_extras.types import Param + from pyglotaran_extras.types import RetType + + +class DefaultKwarg(TypedDict): + """Default value and type annotation of a kwarg extracted from the function signature.""" + + default: Any + annotation: str + docstring: str | None + + +DefaultKwargs: TypeAlias = Mapping[str, DefaultKwarg] +__PlotFunctionRegistry: MutableMapping[str, DefaultKwargs] = {} + + +@add_yaml_repr +class PlotLabelOverrideValue(BaseModel): + """Value of ``PlotLabelOverrideMap``.""" + + model_config = ConfigDict(extra="forbid") + + target_name: str + axis: Literal["x", "y", "both"] = "both" + + @model_serializer + def serialize(self) -> dict[str, Any] | str: + """Serialize supporting short notation. + + Returns + ------- + dict[str, Any] | str + """ + if self.axis == "both": + return self.target_name + return {"target_name": self.target_name, "axis": self.axis} + + +def _add_short_notation_to_schema(json_schema: dict[str, Any]) -> None: # noqa: DOC + """Update json schema to support short notation for ``PlotLabelOverrideValue``.""" + orig_additional_properties = json_schema["additionalProperties"] + json_schema["additionalProperties"] = { + "anyOf": [orig_additional_properties, {"type": "string"}] + } + + +@add_yaml_repr +class PlotLabelOverrideMap(RootModel, Mapping): + """Mapping to override axis labels.""" + + model_config = ConfigDict(json_schema_extra=_add_short_notation_to_schema) + + root: dict[str, PlotLabelOverrideValue] = Field(default_factory=dict) + + @model_validator(mode="before") + @classmethod + def parse(cls, values: dict[str, Any]) -> dict[str, PlotLabelOverrideValue]: # noqa: DOC + """Parse ``axis_label_override`` dictionary supporting verbose and short notation. + + Parameters + ---------- + values : dict[str, Any] + Dict that initializes the class. + + Returns + ------- + dict[str, PlotLabelOverrideValue] + """ + if values is PydanticUndefined or values is None: + return {} + errors: dict[str, ErrorDetails] = {} + parsed_values: dict[str, PlotLabelOverrideValue] = {} + for key, value in values.items(): + try: + if isinstance(value, str): + parsed_values[key] = PlotLabelOverrideValue(target_name=value) + else: + parsed_values[key] = PlotLabelOverrideValue.model_validate(value) + except ValidationError as error: + errors |= {str(e): e for e in error.errors()} + if len(errors) > 0: + raise ValidationError.from_exception_data(cls.__name__, line_errors=[*errors.values()]) # type:ignore[list-item] + return parsed_values + + def __iter__(self) -> Iterator[str]: # type:ignore[override] # noqa: DOC + """Iterate over items.""" + return iter(self.root) + + def __len__(self) -> int: # noqa: DOC + """Get number of items.""" + return len(self.root) + + def __getitem__(self, item_label: str) -> PlotLabelOverrideValue: # noqa: DOC + """Access items.""" + return self.root[item_label] + + def __contains__(self, item_label: object) -> bool: # noqa: DOC + """Check if item is ``in`` the object.""" + return item_label in self.root + + def find_axis_label(self, matplotlib_label: str, axis_name: Literal["x", "y"]) -> str | None: + """Find axis label even if ``matplotlib`` or the user added a newline in it. + + Parameters + ---------- + matplotlib_label : str + Label extracted from the ``matplotlib`` ``Axes`` with ``ax.get_xlabel()`` or + ``ax.get_xlabel()``. + axis_name : Literal["x", "y"] + Name of the axis to find the label for. + + Returns + ------- + str | None + Mapped label value if found and None otherwise. + """ + if matplotlib_label in self and self[matplotlib_label].axis in (axis_name, "both"): + return self[matplotlib_label].target_name + + # If a label is too long to fit matplotlib inserts a newline which means we can not look it + # up with string equality + for key, value in self.root.items(): + if matplotlib_label.replace("\n", "") == key.replace("\n", "") and value.axis in ( + axis_name, + "both", + ): + return value.target_name + return None + + +@add_yaml_repr +class PerFunctionPlotConfig(BaseModel): + """Per function plot configuration.""" + + model_config = ConfigDict(extra="forbid") + + default_args_override: dict[str, Any] = Field( + default_factory=dict, + description="Default arguments to use if not specified in function call.", + ) + axis_label_override: PlotLabelOverrideMap | dict[str, str] = Field( + default_factory=PlotLabelOverrideMap + ) + + @field_validator("axis_label_override", mode="before") + @classmethod + def validate_axis_label_override( # noqa: DOC + cls, value: PlotLabelOverrideMap | dict[str, str] + ) -> PlotLabelOverrideMap: + """Ensure that ``axis_label_override`` gets converted into ``PlotLabelOverrideMap``.""" + return PlotLabelOverrideMap.model_validate(value) + + @model_serializer + def serialize(self) -> dict[str, Any]: + """Serialize in a sparse manner leaving out empty values. + + Returns + ------- + dict[str, Any] + """ + serialized = {} + if len(self.default_args_override) > 0: + serialized["default_args_override"] = self.default_args_override + if len(self.axis_label_override) > 0: + serialized["axis_label_override"] = cast( + PlotLabelOverrideMap, self.axis_label_override + ).model_dump() + return serialized + + def merge(self, other: PerFunctionPlotConfig) -> PerFunctionPlotConfig: + """Merge two ``PerFunctionPlotConfig``'s where ``other`` overrides values. + + Parameters + ---------- + other : PerFunctionPlotConfig + Other ``PerFunctionPlotConfig`` to merge in. + + Returns + ------- + PerFunctionPlotConfig + """ + self_dict = self.model_dump() + other_dict = other.model_dump() + return PerFunctionPlotConfig.model_validate( + { + "default_args_override": ( + self_dict.pop("default_args_override", {}) + | other_dict.pop("default_args_override", {}) + ), + "axis_label_override": ( + self_dict.pop("axis_label_override", {}) + | other_dict.pop("axis_label_override", {}) + ), + } + ) + + def find_override_kwargs(self, not_user_provided_kwargs: set[str]) -> dict[str, Any]: + """Config key word arguments that were not provided by the user and are safe to override. + + Parameters + ---------- + not_user_provided_kwargs : set[str] + Set of keyword arguments that were provided by the user and thus should not be + overridden. + + Returns + ------- + dict[str, Any] + """ + return { + k: self.default_args_override[k] + for k in self.default_args_override + if k in not_user_provided_kwargs + } + + def update_axes_labels(self, axes: Axes | Iterable[Axes]) -> None: + """Apply label overrides to ``axes``. + + Parameters + ---------- + axes : Axes | Iterable[Axes] + Axes to apply the override to. + """ + if isinstance(axes, Axes): + self.update_axes_labels((axes,)) + return + for ax in axes: + if isinstance(ax, Axes): + orig_x_label = ax.get_xlabel() + orig_y_label = ax.get_ylabel() + axis_label_override = cast(PlotLabelOverrideMap, self.axis_label_override) + + if ( + override_label := axis_label_override.find_axis_label(orig_x_label, "x") + ) is not None: + ax.set_xlabel(override_label) + + if ( + override_label := axis_label_override.find_axis_label(orig_y_label, "y") + ) is not None: + ax.set_ylabel(override_label) + + elif isinstance(ax, np.ndarray): + self.update_axes_labels(ax.flatten()) + else: + self.update_axes_labels(ax) + + +@add_yaml_repr +class PlotConfig(BaseModel): + """Config for plot functions including default args and label overrides.""" + + model_config = ConfigDict(extra="allow") + + general: PerFunctionPlotConfig = Field( + default_factory=PerFunctionPlotConfig, + description="Config that gets applied to all functions if not specified otherwise.", + ) + + @model_validator(mode="before") + @classmethod + def parse(cls, values: dict[str, Any]) -> dict[str, PerFunctionPlotConfig]: + """Ensure the extra values are converted to ``PerFunctionPlotConfig``. + + Parameters + ---------- + values : dict[str, Any] + Dict that initializes the class. + + Returns + ------- + dict[str, PerFunctionPlotConfig] + + Raises + ------ + ValidationError + """ + parsed_values = {} + errors: dict[str, ErrorDetails] = {} + for key, value in values.items(): + try: + parsed_values[key] = PerFunctionPlotConfig.model_validate(value) + except ValidationError as error: + errors |= {str(e): {**e, "loc": (key, *e["loc"])} for e in error.errors()} + if len(errors) > 0: + raise ValidationError.from_exception_data(cls.__name__, line_errors=[*errors.values()]) # type:ignore[list-item] + return parsed_values + + def get_function_config(self, function_name: str) -> PerFunctionPlotConfig: + """Get config for a specific function. + + Parameters + ---------- + function_name : str + Name of the function to get the config for. + + Returns + ------- + PerFunctionPlotConfig + """ + function_config = self.general + if self.model_extra is not None and function_name in self.model_extra: + function_config = function_config.merge(self.model_extra[function_name]) + if hasattr(self, "__context_config"): + function_config = function_config.merge(getattr(self, "__context_config")) + return function_config + + def merge(self, other: PlotConfig) -> PlotConfig: # noqa: C901 + """Merge two ``PlotConfig``'s where ``other`` overrides values. + + Parameters + ---------- + other : PlotConfig + Other ``PlotConfig`` to merge in. + + Returns + ------- + PlotConfig + """ + updated: dict[str, PerFunctionPlotConfig] = {} + # Update general field + for key in self.model_fields_set: + updated[key] = cast(PerFunctionPlotConfig, getattr(self, key)) + if key in other.model_fields_set: + updated[key] = updated[key].merge(cast(PerFunctionPlotConfig, getattr(other, key))) + for key in other.model_fields_set: + if key not in updated: + updated[key] = getattr(other, key) + # Update model_extra + if self.model_extra is not None: + for key, value in self.model_extra.items(): + updated[key] = cast(PerFunctionPlotConfig, value) + if other.model_extra is not None and key in other.model_extra: + updated[key] = updated[key].merge( + cast(PerFunctionPlotConfig, other.model_extra[key]) + ) + if other.model_extra is not None: + for key, value in other.model_extra.items(): + if key not in updated: + updated[key] = value + + return PlotConfig.model_validate(updated) + + +def create_parameter_docstring_mapping(func: Callable[..., Any]) -> Mapping[str, str]: + """Create a mapping of parameter names and they docstrings. + + Parameters + ---------- + func : Callable[..., Any] + Function to create the parameter docstring mapping for. + + Returns + ------- + Mapping[str, str] + """ + param_docstring_mapping = {} + for param in parse_docstring(func.__doc__ if func.__doc__ is not None else "").params: + if param.description is not None: + param_docstring_mapping[param.arg_name] = " ".join(param.description.splitlines()) + return param_docstring_mapping + + +def extract_default_kwargs( + func: Callable[..., Any], exclude_kwargs: tuple[str, ...] +) -> DefaultKwargs: + """Extract the default kwargs of ``func`` from its signature. + + Parameters + ---------- + func : Callable[..., Any] + Function to extract the default args from. + exclude_kwargs : tuple[str, ...] + Names of keyword arguments that should be excluded. + + Returns + ------- + DefaultKwargs + + See Also + -------- + use_plot_config + """ + sig = signature(func) + param_docstring_mapping = create_parameter_docstring_mapping(func) + return { + k: { + "default": v.default, + "annotation": v.annotation if v.annotation is not Parameter.empty else "object", + "docstring": param_docstring_mapping.get(k, None), + } + for k, v in sig.parameters.items() + if k not in exclude_kwargs + and v.default is not Parameter.empty + and v.kind is not Parameter.POSITIONAL_ONLY + } + + +def find_not_user_provided_kwargs( + default_kwargs: DefaultKwargs, arg_names: Iterable[str], kwargs: Mapping[str, Any] +) -> set[str]: + """Find which kwargs of a function were not provided by the user. + + Those kwargs can be overridden by config value. + + Parameters + ---------- + default_kwargs : DefaultKwargs + Default keyword arguments to the function. + arg_names : Iterable[str] + Names of the positional arguments passed when calling the function. + kwargs : Mapping[str, Any] + Kwargs passed when calling the function. + + Returns + ------- + set[str] + + See Also + -------- + extract_default_kwargs + """ + return {k for k in default_kwargs if k not in kwargs and k not in arg_names} + + +def find_axes( + values: Iterable[Any], +) -> Generator[Axes, None, None]: + """Iterate over values and yield the values that are ``Axes``. + + Parameters + ---------- + values : Iterable[Any] + Values to look for an ``Axes`` values in. + + Yields + ------ + Axes + """ + for value in values: + if isinstance(value, str): + continue + elif isinstance(value, Axes): + yield value + elif isinstance(value, np.ndarray): + yield from find_axes(value.flatten()) + elif isinstance(value, Iterable): + yield from find_axes(value) + + +def use_plot_config( # noqa: DOC201, DOC203 + exclude_from_config: tuple[str, ...] = (), +) -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]: + """Decorate plot functions to register it and enables auto use of config. + + Parameters + ---------- + exclude_from_config : tuple[str, ...] + Names of keyword argument with default for which the type can not be represent in the + config. Defaults to () + """ + + def outer_wrapper(func: Callable[Param, RetType]) -> Callable[Param, RetType]: # noqa: DOC + """Outer wrapper to allow for ``ignore_kwargs`` to be passed.""" + default_kwargs = extract_default_kwargs(func, exclude_from_config) + __PlotFunctionRegistry[func.__name__] = default_kwargs + + @wraps(func) + def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> RetType: # noqa: DOC + """Wrap function and apply config.""" + from pyglotaran_extras import CONFIG + + CONFIG.reload() + + arg_names = func.__code__.co_varnames[: len(args)] + not_user_provided_kwargs = find_not_user_provided_kwargs( + default_kwargs, arg_names, kwargs + ) + function_config = CONFIG.plotting.get_function_config(func.__name__) + override_kwargs = function_config.find_override_kwargs(not_user_provided_kwargs) + updated_kwargs = kwargs | override_kwargs + arg_axes = find_axes(getcallargs(func, *args, **updated_kwargs).values()) + return_values = func(*args, **updated_kwargs) + function_config.update_axes_labels(arg_axes) + + if isinstance(return_values, Iterable): + return_axes = find_axes(return_values) + function_config.update_axes_labels(return_axes) + + return return_values + + return wrapper + + return outer_wrapper + + +@contextmanager +def plot_config_context(plot_config: PerFunctionPlotConfig) -> Generator[Config, None, None]: + """Context manager to override parts of the resolved functions ``PlotConfig``. + + Parameters + ---------- + plot_config : PerFunctionPlotConfig + Function plot config override to update plot config for functions run inside of context. + + Yields + ------ + Config + """ + from pyglotaran_extras import CONFIG + + setattr( + CONFIG.plotting, + "__context_config", + PerFunctionPlotConfig.model_validate(plot_config), + ) + yield CONFIG + delattr(CONFIG.plotting, "__context_config") diff --git a/pyglotaran_extras/config/utils.py b/pyglotaran_extras/config/utils.py new file mode 100644 index 00000000..36b0b565 --- /dev/null +++ b/pyglotaran_extras/config/utils.py @@ -0,0 +1,53 @@ +"""Module containing config utilities.""" + +from __future__ import annotations + +from io import StringIO +from typing import TYPE_CHECKING + +from ruamel.yaml import YAML + +if TYPE_CHECKING: + from pyglotaran_extras.types import SupportsModelDump + + +def to_yaml_str(self: SupportsModelDump) -> str: + """Create yaml string from dumped model. + + Parameters + ---------- + self : SupportsModelDump + Instance of a class that supports ``model_dump``. + + Returns + ------- + str + + See Also + -------- + add_yaml_repr + """ + yaml = YAML() + yaml.indent(mapping=2, sequence=4, offset=2) + buffer = StringIO() + yaml.dump(self.model_dump(), buffer) + buffer.seek(0) + return buffer.read() + + +def add_yaml_repr(cls: type[SupportsModelDump]) -> type[SupportsModelDump]: + """Add yaml ``__str__`` and ``_repr_markdown_`` methods to class that supports ``model_dump``. + + Parameters + ---------- + cls : type[SupportsModelDump] + Class to add the methods to. + + Returns + ------- + type[SupportsModelDump] + """ + cls.__str__ = to_yaml_str # type:ignore[method-assign] + cls._repr_markdown_ = lambda self: f"```yaml\n{self}\n```" # type:ignore[attr-defined] + + return cls diff --git a/pyglotaran_extras/deprecation/deprecation_utils.py b/pyglotaran_extras/deprecation/deprecation_utils.py index 818129f0..06479de6 100644 --- a/pyglotaran_extras/deprecation/deprecation_utils.py +++ b/pyglotaran_extras/deprecation/deprecation_utils.py @@ -11,7 +11,7 @@ ) -class OverDueDeprecationError(Exception): +class OverdueDeprecationError(Exception): """Error thrown when a deprecation should have been removed. See Also @@ -96,7 +96,7 @@ def check_overdue(deprecated_qual_name_usage: str, to_be_removed_in_version: str Raises ------ - OverDueDeprecation + OverdueDeprecation If the current version is greater or equal to ``to_be_removed_in_version``. """ if ( @@ -108,7 +108,7 @@ def check_overdue(deprecated_qual_name_usage: str, to_be_removed_in_version: str f"was supposed to be dropped in version: {to_be_removed_in_version!r}.\n" f"Current version is: {pyglotaran_extras_version()!r}" ) - raise OverDueDeprecationError(msg) + raise OverdueDeprecationError(msg) def warn_deprecated( @@ -138,7 +138,7 @@ def warn_deprecated( Raises ------ - OverDueDeprecation + OverdueDeprecation If the current version is greater or equal to ``to_be_removed_in_version``. """ check_overdue(deprecated_qual_name_usage, to_be_removed_in_version) diff --git a/pyglotaran_extras/plotting/plot_coherent_artifact.py b/pyglotaran_extras/plotting/plot_coherent_artifact.py index d65bceae..dcc9fa02 100644 --- a/pyglotaran_extras/plotting/plot_coherent_artifact.py +++ b/pyglotaran_extras/plotting/plot_coherent_artifact.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +from pyglotaran_extras.config.plot_config import use_plot_config from pyglotaran_extras.io.load_data import load_data from pyglotaran_extras.plotting.utils import abs_max from pyglotaran_extras.plotting.utils import add_cycler_if_not_none @@ -23,6 +24,7 @@ from pyglotaran_extras.types import DatasetConvertible +@use_plot_config(exclude_from_config=("cycler",)) def plot_coherent_artifact( dataset: DatasetConvertible | Result, *, diff --git a/pyglotaran_extras/plotting/plot_concentrations.py b/pyglotaran_extras/plotting/plot_concentrations.py index 319df8f5..dfa2816e 100644 --- a/pyglotaran_extras/plotting/plot_concentrations.py +++ b/pyglotaran_extras/plotting/plot_concentrations.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING +from pyglotaran_extras.config.plot_config import use_plot_config from pyglotaran_extras.plotting.style import PlotStyle from pyglotaran_extras.plotting.utils import MinorSymLogLocator from pyglotaran_extras.plotting.utils import add_cycler_if_not_none @@ -15,6 +16,7 @@ from matplotlib.axis import Axis +@use_plot_config(exclude_from_config=("cycler",)) def plot_concentrations( res: xr.Dataset, ax: Axis, diff --git a/pyglotaran_extras/plotting/plot_data.py b/pyglotaran_extras/plotting/plot_data.py index 7bc3939e..40072f1c 100644 --- a/pyglotaran_extras/plotting/plot_data.py +++ b/pyglotaran_extras/plotting/plot_data.py @@ -9,6 +9,7 @@ from glotaran.io.prepare_dataset import add_svd_to_dataset from matplotlib.axis import Axis +from pyglotaran_extras.config.plot_config import use_plot_config from pyglotaran_extras.io.load_data import load_data from pyglotaran_extras.plotting.plot_svd import plot_lsv_data from pyglotaran_extras.plotting.plot_svd import plot_rsv_data @@ -32,6 +33,7 @@ from pyglotaran_extras.types import DatasetConvertible +@use_plot_config(exclude_from_config=("svd_cycler",)) def plot_data_overview( dataset: DatasetConvertible | Result, title: str = "Data overview", diff --git a/pyglotaran_extras/plotting/plot_doas.py b/pyglotaran_extras/plotting/plot_doas.py index fa80d998..5fcc3e55 100644 --- a/pyglotaran_extras/plotting/plot_doas.py +++ b/pyglotaran_extras/plotting/plot_doas.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +from pyglotaran_extras.config.plot_config import use_plot_config from pyglotaran_extras.io.load_data import load_data from pyglotaran_extras.plotting.style import PlotStyle from pyglotaran_extras.plotting.utils import abs_max @@ -25,6 +26,7 @@ from pyglotaran_extras.types import DatasetConvertible +@use_plot_config(exclude_from_config=("cycler",)) def plot_doas( dataset: DatasetConvertible | Result, *, diff --git a/pyglotaran_extras/plotting/plot_guidance.py b/pyglotaran_extras/plotting/plot_guidance.py index b3ace952..b33f5e41 100644 --- a/pyglotaran_extras/plotting/plot_guidance.py +++ b/pyglotaran_extras/plotting/plot_guidance.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt +from pyglotaran_extras.config.plot_config import use_plot_config from pyglotaran_extras.io.load_data import load_data from pyglotaran_extras.plotting.style import PlotStyle from pyglotaran_extras.plotting.utils import add_cycler_if_not_none @@ -19,6 +20,7 @@ from pyglotaran_extras.types import DatasetConvertible +@use_plot_config(exclude_from_config=("cycler",)) def plot_guidance( result: DatasetConvertible | Result, figsize: tuple[float, float] = (15, 5), diff --git a/pyglotaran_extras/plotting/plot_irf_dispersion_center.py b/pyglotaran_extras/plotting/plot_irf_dispersion_center.py index 0a570df2..7f015e0d 100644 --- a/pyglotaran_extras/plotting/plot_irf_dispersion_center.py +++ b/pyglotaran_extras/plotting/plot_irf_dispersion_center.py @@ -10,6 +10,7 @@ from matplotlib.axis import Axis from matplotlib.figure import Figure +from pyglotaran_extras.config.plot_config import use_plot_config from pyglotaran_extras.io.utils import result_dataset_mapping from pyglotaran_extras.plotting.style import PlotStyle from pyglotaran_extras.plotting.utils import add_cycler_if_not_none @@ -23,6 +24,7 @@ from pyglotaran_extras.types import ResultLike +@use_plot_config(exclude_from_config=("cycler", "ax")) def plot_irf_dispersion_center( result: ResultLike, ax: Axis | None = None, diff --git a/pyglotaran_extras/plotting/plot_overview.py b/pyglotaran_extras/plotting/plot_overview.py index 1a489069..af12042c 100644 --- a/pyglotaran_extras/plotting/plot_overview.py +++ b/pyglotaran_extras/plotting/plot_overview.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt +from pyglotaran_extras.config.plot_config import use_plot_config from pyglotaran_extras.deprecation.deprecation_utils import FIG_ONLY_WARNING from pyglotaran_extras.deprecation.deprecation_utils import PyglotaranExtrasApiDeprecationWarning from pyglotaran_extras.io.load_data import load_data @@ -34,6 +35,7 @@ from pyglotaran_extras.types import UnsetType +@use_plot_config(exclude_from_config=("cycler", "das_cycler", "svd_cycler")) def plot_overview( result: DatasetConvertible | Result, center_λ: float | None = None, @@ -176,6 +178,7 @@ def plot_overview( return fig, axes +@use_plot_config(exclude_from_config=("cycler", "svd_cycler")) def plot_simple_overview( result: DatasetConvertible | Result, title: str | None = None, diff --git a/pyglotaran_extras/plotting/plot_pfid.py b/pyglotaran_extras/plotting/plot_pfid.py index 22e40bad..0cca81a9 100644 --- a/pyglotaran_extras/plotting/plot_pfid.py +++ b/pyglotaran_extras/plotting/plot_pfid.py @@ -168,12 +168,6 @@ def plot_pfid( # noqa: C901 axes[0, 0].set_title(f"Cos Oscillations {spectral}") axes[1, 0].set_title(f"Sin Oscillations {spectral}") - axes[0, 0].set_xlabel("Time (ps)") - axes[1, 0].set_xlabel("Time (ps)") - axes[0, 1].set_xlabel("Wavenumber (1/cm)") - axes[1, 1].set_xlabel("Wavenumber (1/cm)") - axes[0, 2].set_xlabel("Wavenumber (1/cm)") - axes[1, 2].set_xlabel("Wavenumber (1/cm)") else: axes[0].set_title(f"{oscillation_type.capitalize()} Oscillations {spectral}") diff --git a/pyglotaran_extras/plotting/plot_residual.py b/pyglotaran_extras/plotting/plot_residual.py index 2e88feb2..afed8cf3 100644 --- a/pyglotaran_extras/plotting/plot_residual.py +++ b/pyglotaran_extras/plotting/plot_residual.py @@ -6,6 +6,7 @@ import numpy as np +from pyglotaran_extras.config.plot_config import use_plot_config from pyglotaran_extras.plotting.plot_irf_dispersion_center import _plot_irf_dispersion_center from pyglotaran_extras.plotting.style import PlotStyle from pyglotaran_extras.plotting.utils import MinorSymLogLocator @@ -18,6 +19,7 @@ from matplotlib.axis import Axis +@use_plot_config(exclude_from_config=("cycler",)) def plot_residual( res: xr.Dataset, ax: Axis, diff --git a/pyglotaran_extras/plotting/plot_spectra.py b/pyglotaran_extras/plotting/plot_spectra.py index 76f0d6de..b55b43ac 100644 --- a/pyglotaran_extras/plotting/plot_spectra.py +++ b/pyglotaran_extras/plotting/plot_spectra.py @@ -6,6 +6,7 @@ import numpy as np +from pyglotaran_extras.config.plot_config import use_plot_config from pyglotaran_extras.plotting.style import PlotStyle from pyglotaran_extras.plotting.utils import add_cycler_if_not_none from pyglotaran_extras.types import Unset @@ -19,6 +20,7 @@ from pyglotaran_extras.types import UnsetType +@use_plot_config(exclude_from_config=("cycler", "das_cycler")) def plot_spectra( res: xr.Dataset, axes: Axes, @@ -51,6 +53,7 @@ def plot_spectra( plot_norm_das(res, axes[1, 1], cycler=das_cycler, show_zero_line=show_zero_line) +@use_plot_config(exclude_from_config=("cycler",)) def plot_sas( res: xr.Dataset, ax: Axis, @@ -86,6 +89,7 @@ def plot_sas( ax.axhline(0, color="k", linewidth=1) +@use_plot_config(exclude_from_config=("cycler",)) def plot_norm_sas( res: xr.Dataset, ax: Axis, @@ -123,6 +127,7 @@ def plot_norm_sas( ax.axhline(0, color="k", linewidth=1) +@use_plot_config(exclude_from_config=("cycler",)) def plot_das( res: xr.Dataset, ax: Axis, @@ -158,6 +163,7 @@ def plot_das( ax.axhline(0, color="k", linewidth=1) +@use_plot_config(exclude_from_config=("cycler",)) def plot_norm_das( res: xr.Dataset, ax: Axis, diff --git a/pyglotaran_extras/plotting/plot_svd.py b/pyglotaran_extras/plotting/plot_svd.py index 4ea639e8..ada8a441 100644 --- a/pyglotaran_extras/plotting/plot_svd.py +++ b/pyglotaran_extras/plotting/plot_svd.py @@ -6,6 +6,7 @@ from glotaran.io.prepare_dataset import add_svd_to_dataset +from pyglotaran_extras.config.plot_config import use_plot_config from pyglotaran_extras.deprecation import warn_deprecated from pyglotaran_extras.plotting.style import PlotStyle from pyglotaran_extras.plotting.utils import MinorSymLogLocator @@ -23,6 +24,7 @@ from matplotlib.pyplot import Axes +@use_plot_config(exclude_from_config=("cycler",)) def plot_svd( res: xr.Dataset, axes: Axes, @@ -115,10 +117,11 @@ def plot_svd( plot_sv_data(res, axes[1, 2], use_svd_number=use_svd_number) +@use_plot_config(exclude_from_config=("cycler",)) def plot_lsv_data( res: xr.Dataset, ax: Axis, - indices: Sequence[int] = range(4), + indices: Sequence[int] = tuple(range(4)), linlog: bool = False, linthresh: float = 1, cycler: Cycler | None = PlotStyle().cycler, @@ -135,7 +138,7 @@ def plot_lsv_data( ax : Axis Axis to plot on. indices : Sequence[int] - Indices of the singular vector to plot. Defaults to range(4). + Indices of the singular vector to plot. Defaults to tuple(range(4)). linlog : bool Whether to use 'symlog' scale or not. Defaults to False. linthresh : float @@ -169,10 +172,11 @@ def plot_lsv_data( ax.xaxis.set_minor_locator(MinorSymLogLocator(linthresh)) +@use_plot_config(exclude_from_config=("cycler",)) def plot_rsv_data( res: xr.Dataset, ax: Axis, - indices: Sequence[int] = range(4), + indices: Sequence[int] = tuple(range(4)), cycler: Cycler | None = PlotStyle().cycler, show_legend: bool = True, irf_location: float | None = None, @@ -187,7 +191,7 @@ def plot_rsv_data( ax : Axis Axis to plot on. indices : Sequence[int] - Indices of the singular vector to plot. Defaults to range(4). + Indices of the singular vector to plot. Defaults to tuple(range(4)). cycler : Cycler | None Plot style cycler to use. Defaults to PlotStyle().cycler. show_legend : bool @@ -213,10 +217,11 @@ def plot_rsv_data( ax.set_title("data. RSV") +@use_plot_config(exclude_from_config=("cycler",)) def plot_sv_data( res: xr.Dataset, ax: Axis, - indices: Sequence[int] = range(10), + indices: Sequence[int] = tuple(range(10)), cycler: Cycler | None | UnsetType = Unset, use_svd_number: bool = False, ) -> None: @@ -229,7 +234,7 @@ def plot_sv_data( ax : Axis Axis to plot on. indices : Sequence[int] - Indices of the singular vector to plot. Defaults to range(10). + Indices of the singular vector to plot. Defaults to tuple(range(10)). cycler : Cycler | None | UnsetType Deprecated since it has no effect. Defaults to Unset. use_svd_number : bool @@ -249,16 +254,17 @@ def plot_sv_data( dSV = dSV.assign_coords( # noqa: N806 {x_dim: ("singular_value_index", (dSV.singular_value_index + 1).data)} ) - dSV.sel({"singular_value_index": indices[: len(dSV.singular_value_index)]}).plot.line( + dSV.sel({"singular_value_index": list(indices[: len(dSV.singular_value_index)])}).plot.line( "ro-", yscale="log", ax=ax, x=x_dim ) ax.set_title("data. log(SV)") +@use_plot_config(exclude_from_config=("cycler",)) def plot_lsv_residual( res: xr.Dataset, ax: Axis, - indices: Sequence[int] = range(2), + indices: Sequence[int] = tuple(range(2)), linlog: bool = False, linthresh: float = 1, cycler: Cycler | None = PlotStyle().cycler, @@ -275,7 +281,7 @@ def plot_lsv_residual( ax : Axis Axis to plot on. indices : Sequence[int] - Indices of the singular vector to plot. Defaults to range(4). + Indices of the singular vector to plot. Defaults to tuple(range(4)). linlog : bool Whether to use 'symlog' scale or not. Defaults to False. linthresh : float @@ -314,10 +320,11 @@ def plot_lsv_residual( ax.xaxis.set_minor_locator(MinorSymLogLocator(linthresh)) +@use_plot_config(exclude_from_config=("cycler",)) def plot_rsv_residual( res: xr.Dataset, ax: Axis, - indices: Sequence[int] = range(2), + indices: Sequence[int] = tuple(range(2)), cycler: Cycler | None = PlotStyle().cycler, show_legend: bool = True, irf_location: float | None = None, @@ -332,7 +339,7 @@ def plot_rsv_residual( ax : Axis Axis to plot on. indices : Sequence[int] - Indices of the singular vector to plot. Defaults to range(4). + Indices of the singular vector to plot. Defaults to tuple(range(4)). cycler : Cycler | None Plot style cycler to use. Defaults to PlotStyle().cycler. show_legend : bool @@ -362,10 +369,11 @@ def plot_rsv_residual( ax.set_title("res. RSV") +@use_plot_config(exclude_from_config=("cycler",)) def plot_sv_residual( res: xr.Dataset, ax: Axis, - indices: Sequence[int] = range(10), + indices: Sequence[int] = tuple(range(10)), cycler: Cycler | None | UnsetType = Unset, use_svd_number: bool = False, ) -> None: @@ -378,7 +386,7 @@ def plot_sv_residual( ax : Axis Axis to plot on. indices : Sequence[int] - Indices of the singular vector to plot. Defaults to range(10). + Indices of the singular vector to plot. Defaults to tuple(range(10)). cycler : Cycler | None | UnsetType Deprecated since it has no effect. Defaults to Unset. use_svd_number : bool @@ -402,7 +410,7 @@ def plot_sv_residual( rSV = rSV.assign_coords( # noqa: N806 {x_dim: ("singular_value_index", (rSV.singular_value_index + 1).data)} ) - rSV.sel({"singular_value_index": indices[: len(rSV.singular_value_index)]}).plot.line( + rSV.sel({"singular_value_index": list(indices[: len(rSV.singular_value_index)])}).plot.line( "ro-", yscale="log", ax=ax, x=x_dim ) ax.set_title("res. log(SV)") @@ -447,7 +455,9 @@ def _plot_svd_vectors( """ max_index = len(getattr(vector_data, sv_index_dim)) values = shift_time_axis_by_irf_location( - vector_data.isel({sv_index_dim: indices[:max_index]}), irf_location, _internal_call=True + vector_data.isel({sv_index_dim: list(indices[:max_index])}), + irf_location, + _internal_call=True, ) x_dim = vector_data.dims[1] if x_dim == sv_index_dim: diff --git a/pyglotaran_extras/plotting/plot_traces.py b/pyglotaran_extras/plotting/plot_traces.py index fa79f1ed..bb215c1e 100644 --- a/pyglotaran_extras/plotting/plot_traces.py +++ b/pyglotaran_extras/plotting/plot_traces.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt +from pyglotaran_extras.config.plot_config import use_plot_config from pyglotaran_extras.io.utils import result_dataset_mapping from pyglotaran_extras.plotting.style import PlotStyle from pyglotaran_extras.plotting.utils import MinorSymLogLocator @@ -31,6 +32,7 @@ from pyglotaran_extras.types import ResultLike +@use_plot_config(exclude_from_config=("cycler",)) def plot_data_and_fits( result: ResultLike, wavelength: float, @@ -108,6 +110,7 @@ def plot_data_and_fits( axis.legend() +@use_plot_config(exclude_from_config=("cycler",)) def plot_fitted_traces( result: ResultLike, wavelengths: Iterable[float], diff --git a/pyglotaran_extras/types.py b/pyglotaran_extras/types.py index cbf013d0..b7960cfb 100644 --- a/pyglotaran_extras/types.py +++ b/pyglotaran_extras/types.py @@ -6,12 +6,17 @@ from collections.abc import Sequence from pathlib import Path from typing import TYPE_CHECKING +from typing import Any from typing import Literal +from typing import ParamSpec from typing import TypeAlias from typing import TypedDict +from typing import TypeVar import xarray as xr from glotaran.project.result import Result +from pydantic import BaseModel +from pydantic import RootModel if TYPE_CHECKING: from pyglotaran_extras.plotting.style import ColorCode @@ -68,3 +73,9 @@ class CyclerColor(TypedDict): SubPlotLabelCoord: TypeAlias = ( SubPlotLabelCoordStrs | tuple[SubPlotLabelCoordStrs, SubPlotLabelCoordStrs] ) + + +Param = ParamSpec("Param") +RetType = TypeVar("RetType") + +SupportsModelDump = TypeVar("SupportsModelDump", bound=(BaseModel | RootModel[Any])) diff --git a/pyproject.toml b/pyproject.toml index 0c3073df..9a00508a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,9 +32,12 @@ dynamic = [ ] dependencies = [ "cycler>=0.10", + "docstring-parser>=0.16", "matplotlib>=3.3", "numpy>=1.22", + "pydantic>=2.7", "pyglotaran>=0.7", + "ruamel-yaml>=0.18.6", "tabulate>=0.8.9", "xarray>=2022.3", ] @@ -42,18 +45,21 @@ optional-dependencies.dev = [ "pyglotaran-extras[docs,test]", ] optional-dependencies.docs = [ + "autodoc-pydantic>=2.2", "jupyterlab>=3", + "myst-nb>=1.1.1", # notebook docs "myst-parser>=0.12", - "nbsphinx>=0.8.1", # notebook docs "numpydoc>=0.8", "sphinx>=3.2", "sphinx-copybutton>=0.3", "sphinx-last-updated-by-git>=0.3", "sphinx-rtd-theme>=1.2", "sphinxcontrib-jquery>=4.1", # Needed for the search to work Ref.: https://github.com/readthedocs/sphinx_rtd_theme/issues/1434 + "sphinxcontrib-mermaid>=0.9.2", ] optional-dependencies.test = [ "coverage[toml]", + "jsonschema>=4.22", "nbval>=0.9.6", "pluggy>=0.7", "pytest>=3.7.1", @@ -93,11 +99,13 @@ exclude = '^(docs/|tests?/)' require-return-section-when-returning-none = false allow-init-docstring = true +[tool.coverage.paths] +source = [ + "pyglotaran_extras", + "*/site-packages/pyglotaran_extras", +] [tool.coverage.run] branch = true -include = [ - 'pyglotaran_extras/*', -] omit = [ 'tests/*', # comment the above line if you want to see if all tests did run @@ -124,6 +132,10 @@ exclude_lines = [ ] [tool.mypy] +plugins = [ + "pydantic.mypy", +] +python_version = "3.10" ignore_missing_imports = true scripts_are_modules = true show_error_codes = true @@ -134,6 +146,7 @@ disallow_untyped_defs = true disallow_untyped_calls = true no_implicit_reexport = true warn_unused_configs = true +check_untyped_defs = true [[tool.mypy.overrides]] module = "tests.*" diff --git a/readthedocs.yml b/readthedocs.yml index 2e33e9df..9f29891d 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -6,12 +6,9 @@ sphinx: configuration: docs/conf.py build: - os: ubuntu-22.04 + os: ubuntu-lts-latest tools: - python: "3.10" + python: "mambaforge-latest" -python: - install: - - requirements: requirements_pinned.txt - - method: pip - path: .[docs] +conda: + environment: docs/environment.yml diff --git a/requirements_pinned.txt b/requirements_pinned.txt index 08736b47..4d286f7c 100644 --- a/requirements_pinned.txt +++ b/requirements_pinned.txt @@ -6,3 +6,6 @@ matplotlib==3.9.2 pyglotaran==0.7.3 tabulate==0.9.0 xarray==2024.7.0 +ruamel.yaml==0.18.6 +docstring_parser==0.16 +pydantic==2.7.2 diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 00000000..79e3bbb2 --- /dev/null +++ b/tests/config/__init__.py @@ -0,0 +1 @@ +"""Tests for ``pyglotaran_extras.config``.""" diff --git a/tests/config/test_config.py b/tests/config/test_config.py new file mode 100644 index 00000000..2ed883c1 --- /dev/null +++ b/tests/config/test_config.py @@ -0,0 +1,610 @@ +"""Tests for ``pyglotaran_extras.config.config``.""" + +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path +from shutil import copyfile +from typing import Any + +import pytest +from jsonschema import Draft202012Validator +from ruamel.yaml import YAML + +from pyglotaran_extras import create_config_schema +from pyglotaran_extras.config.config import CONFIG_FILE_STEM +from pyglotaran_extras.config.config import Config +from pyglotaran_extras.config.config import discover_config_files +from pyglotaran_extras.config.config import find_config_in_dir +from pyglotaran_extras.config.config import load_config +from pyglotaran_extras.config.config import load_config_files +from pyglotaran_extras.config.config import merge_configs +from pyglotaran_extras.config.plot_config import PlotConfig +from pyglotaran_extras.config.plot_config import use_plot_config +from tests import TEST_DATA +from tests.conftest import generator_is_exhausted + + +def test_config(test_config_values: dict[str, Any]): + """Empty and from values initialization works.""" + empty_config = Config() + + assert empty_config.plotting == PlotConfig() + + test_config = Config.model_validate(test_config_values) + test_plot_config = PlotConfig.model_validate(test_config_values["plotting"]) + + assert test_config.plotting == test_plot_config + + +def test_config_merge(tmp_path: Path): + """Merging creates the expected output. + + - Update fields that are present in original and update + - Keep fields that are not present in the update + - Add field that is only present in update + """ + original_values, update_values, expected_values = tuple( + YAML().load_all(TEST_DATA / "config/config_merge.yml") + ) + original = Config.model_validate(original_values) + update = Config.model_validate(update_values) + expected = Config.model_validate(expected_values) + + assert original.merge(update) == expected + + config_with_paths = Config() + config_with_paths._source_files = [tmp_path / "foo", tmp_path / "bar", tmp_path / "baz"] + [file.touch() for file in config_with_paths._source_files] # type:ignore[func-returns-value] + + update = Config() + update._source_files = [tmp_path / "foo"] + + assert config_with_paths.merge(update)._source_files == [ + tmp_path / "bar", + tmp_path / "baz", + tmp_path / "foo", + ] + + +def test_config_reset(tmp_path: Path, test_config_values: dict[str, Any]): + """All config values but the source files get reset when no other config is passed. + + If another config is passed all values (including source files) are reset. + """ + + test_config_path = tmp_path / f"{CONFIG_FILE_STEM}.yml" + test_config = Config.model_validate(test_config_values) + test_config._source_files = [test_config_path] + + test_config._reset() + + assert test_config.plotting == PlotConfig() + assert test_config._source_files == [test_config_path] + + other = Config.model_validate(test_config_values) + + test_config._reset(other) + + assert test_config.plotting == PlotConfig.model_validate(test_config_values["plotting"]) + assert test_config._source_files == [] + assert test_config == other + + +def test_config_reload(tmp_path: Path, test_config_values: dict[str, Any]): + """Config values get reloaded from file.""" + + test_config_path = tmp_path / f"{CONFIG_FILE_STEM}.yml" + test_config = Config.model_validate(test_config_values) + test_config._source_files = [test_config_path] + + test_config_values["plotting"]["general"]["default_args_override"]["will_update_arg"] = ( + "file got updated" + ) + + YAML().dump(test_config_values, test_config_path) + + expected_plot_config = test_config.plotting.model_copy(deep=True) + expected_plot_config.general.default_args_override["will_update_arg"] = "file got updated" + + test_config.reload() + + assert test_config.plotting.model_dump() == expected_plot_config.model_dump() + assert test_config._source_files == [test_config_path] + + +def test_config_load(tmp_path: Path, test_config_values: dict[str, Any], test_config_file: Path): + """Loading config replaces its values and and source files.""" + + test_config = Config() + test_config._source_files = [test_config_file] + test_config.reload() + + test_config_values["plotting"]["general"]["default_args_override"]["will_update_arg"] = ( + "from new file" + ) + + test_config_path = tmp_path / f"{CONFIG_FILE_STEM}.yml" + YAML().dump(test_config_values, test_config_path) + + expected_plot_config = test_config.plotting.model_copy(deep=True) + expected_plot_config.general.default_args_override["will_update_arg"] = "from new file" + + test_config.load(test_config_path) + + assert test_config.plotting.model_dump() == expected_plot_config.model_dump() + assert test_config._source_files == [test_config_path] + + +def test_config_export(tmp_path: Path, test_config_file: Path): + """Exporting the config gives the expected result.""" + config = Config().load(test_config_file) + + exported_config_path = config.export(tmp_path) + + assert exported_config_path.is_file() is True + assert exported_config_path.samefile(tmp_path / f"{CONFIG_FILE_STEM}.yml") is True + assert (tmp_path / f"{CONFIG_FILE_STEM}.schema.json").is_file() is True + + assert exported_config_path.read_text(encoding="utf8") == test_config_file.read_text( + encoding="utf8" + ) + + +def test_config_export_update(tmp_path: Path, test_config_file: Path): + """Update existing config by default and overwrite if update kwarg is ``False``.""" + config = Config().load(test_config_file) + existing_config_path = config.export(tmp_path) + + update_config = Config( + plotting=PlotConfig.model_validate( + {"test_func": {"axis_label_override": {"will_update_label": "new label"}}} + ) + ) + update_config.export(tmp_path) + # Ensure there is no comparison conflict due to different source file paths + config._source_files = [existing_config_path] + update_config._source_files = [existing_config_path] + + assert Config().load(existing_config_path) == config.merge(update_config) + + update_config.export(tmp_path, update=False) + + expected_config = Config().load(existing_config_path) + assert expected_config._source_hash != update_config._source_hash + + expected_config._source_hash = update_config._source_hash + assert expected_config == update_config + + +@pytest.mark.usefixtures("mock_config") +def test_config_rediscover(tmp_path: Path, mock_home: Path): + """Check that new files are picked up and filters work.""" + config = Config() + assert config._source_files == [] + + home_config_path = mock_home / f"{CONFIG_FILE_STEM}.yml" + home_config_path.touch() + expected_config_path = tmp_path / f"{CONFIG_FILE_STEM}.yml" + config_paths_default = config.rediscover() + + assert config_paths_default == [home_config_path, expected_config_path] + assert config_paths_default == config._source_files + + config_paths_no_home = config.rediscover(include_home_dir=False) + + assert config_paths_no_home == [expected_config_path] + assert config_paths_no_home == config._source_files + + project_config_path = tmp_path / f"project/{CONFIG_FILE_STEM}.yml" + project_config_path.touch() + + config_paths_with_project = config.rediscover() + + assert config_paths_with_project == [ + home_config_path, + expected_config_path, + project_config_path, + ] + assert config_paths_with_project == config._source_files + + config_paths_only_project = config.rediscover(include_home_dir=False, lookup_depth=1) + + assert config_paths_only_project == [project_config_path] + assert config_paths_only_project == config._source_files + + +@pytest.mark.usefixtures("mock_config") +def test_config_init_project(tmp_path: Path, mock_home: Path): + """New config and schema are created.""" + from pyglotaran_extras import CONFIG + from pyglotaran_extras import SCRIPT_DIR + + expected_config_file = SCRIPT_DIR / f"{CONFIG_FILE_STEM}.yml" + expected_schema_file = SCRIPT_DIR / f"{CONFIG_FILE_STEM}.schema.json" + + assert tmp_path in SCRIPT_DIR.parents + assert expected_config_file.is_file() is False + assert expected_schema_file.is_file() is False + + assert expected_config_file not in CONFIG._source_files + + initial_config = CONFIG.model_copy(deep=True) + CONFIG.init_project() + + assert expected_config_file.is_file() is True + assert expected_schema_file.is_file() is True + + assert expected_config_file in CONFIG._source_files + + reloaded_config = Config() + reloaded_config._source_files = [ + SCRIPT_DIR.parent / f"{CONFIG_FILE_STEM}.yml", + expected_config_file, + ] + reloaded_config.reload() + + assert CONFIG == reloaded_config # noqa: SIM300 + assert CONFIG != initial_config # noqa: SIM300 + + +def test_find_config_in_dir(tmp_path: Path): + """Find one or two config files if present.""" + assert len(list(find_config_in_dir(tmp_path))) == 0 + + yml_config_path = tmp_path / f"yml/{CONFIG_FILE_STEM}.yml" + yml_config_path.parent.mkdir() + yml_config_path.touch() + + assert next(find_config_in_dir(yml_config_path.parent)) == yml_config_path + + yaml_config_path = tmp_path / f"yaml/{CONFIG_FILE_STEM}.yaml" + yaml_config_path.parent.mkdir() + yaml_config_path.touch() + + assert next(find_config_in_dir(yaml_config_path.parent)) == yaml_config_path + + multi_config_dir = tmp_path / "multi" + multi_config_dir.mkdir() + (multi_config_dir / f"{CONFIG_FILE_STEM}.yml").touch() + (multi_config_dir / f"{CONFIG_FILE_STEM}.yaml").touch() + (multi_config_dir / f"{CONFIG_FILE_STEM}.json").touch() + (multi_config_dir / f"{CONFIG_FILE_STEM}1.yml").touch() + + assert len(list(find_config_in_dir(multi_config_dir))) == 2 + + +def test_discover_config_files(tmp_path: Path, mock_home: Path): + """Discover all config files in the correct order.""" + file_name = f"{CONFIG_FILE_STEM}.yml" + script_dir = tmp_path / "top_project/project/sub_project" + script_dir.mkdir(parents=True) + + home_config = mock_home / file_name + home_config.touch() + + top_project_config = script_dir.parent.parent / file_name + top_project_config.touch() + + project_config = script_dir.parent / file_name + project_config.touch() + + sub_project_config = script_dir / file_name + sub_project_config.touch() + + default_discovery = discover_config_files(script_dir) + + assert next(default_discovery) == home_config + assert next(default_discovery) == project_config + assert next(default_discovery) == sub_project_config + assert generator_is_exhausted(default_discovery) is True + + no_recursion_discovery = discover_config_files(script_dir, include_home_dir=False) + + assert next(no_recursion_discovery) == project_config + assert next(no_recursion_discovery) == sub_project_config + assert generator_is_exhausted(no_recursion_discovery) is True + + no_recursion_discovery = discover_config_files(script_dir, lookup_depth=1) + + assert next(no_recursion_discovery) == home_config + assert next(no_recursion_discovery) == sub_project_config + assert generator_is_exhausted(no_recursion_discovery) is True + + top_project_discovery = discover_config_files(script_dir, lookup_depth=3) + + assert next(top_project_discovery) == home_config + assert next(top_project_discovery) == top_project_config + assert next(top_project_discovery) == project_config + assert next(top_project_discovery) == sub_project_config + assert generator_is_exhausted(top_project_discovery) is True + + +def test_load_config_files(tmp_path: Path, test_config_file: Path): + """Read configs and add source path.""" + empty_config_file = tmp_path / f"{CONFIG_FILE_STEM}.yml" + empty_config_file.touch() + + empty_file_loaded = load_config_files([empty_config_file]) + + empty_config = next(empty_file_loaded) + + assert empty_config.model_dump() == Config().model_dump() + assert empty_config._source_files == [empty_config_file] + assert generator_is_exhausted(empty_file_loaded) is True + + two_configs = load_config_files([empty_config_file, test_config_file]) + + empty_config = next(two_configs) + + assert empty_config.model_dump() == Config().model_dump() + assert empty_config._source_files == [empty_config_file] + + test_config_values = YAML().load(test_config_file) + + expected_config = Config.model_validate(test_config_values) + + test_config = next(two_configs) + + assert test_config.model_dump() == expected_config.model_dump() + assert test_config._source_files == [test_config_file] + assert generator_is_exhausted(two_configs) is True + + +def test_merge_configs(): + """Check that the most right config overrides other values. + + Since we already tested all permutations for 2 configs in ``test_config_merge`` + this test can fucus on the case with more than 2 configs. + """ + assert merge_configs([]) == Config() + + original_values, update_values, expected_values = tuple( + YAML().load_all(TEST_DATA / "config/config_merge.yml") + ) + additional_update_values = { + "plotting": { + "general": {"default_args_override": {"will_update_arg": "additional update"}} + } + } + expected_values["plotting"]["general"]["default_args_override"]["will_update_arg"] = ( + "additional update" + ) + + original = Config.model_validate(original_values) + update = Config.model_validate(update_values) + additional_update = Config.model_validate(additional_update_values) + expected = Config.model_validate(expected_values) + + assert merge_configs([original, update, additional_update]) == expected + + +def test_load_config(tmp_path: Path, mock_home: Path): + """Load config and check that args are passed on to ``discover_config_files``.""" + assert load_config(tmp_path) == Config() + + yaml = YAML() + original_values, update_values, expected_values = tuple( + yaml.load_all(TEST_DATA / "config/config_merge.yml") + ) + additional_update_values = { + "plotting": { + "general": {"default_args_override": {"will_update_arg": "additional update"}} + } + } + expected_values["plotting"]["general"]["default_args_override"]["will_update_arg"] = ( + "additional update" + ) + + file_name = f"{CONFIG_FILE_STEM}.yml" + script_dir = tmp_path / "top_project/project/sub_project" + script_dir.mkdir(parents=True) + + home_config_path = mock_home / file_name + yaml.dump(original_values, home_config_path) + + project_config_path = script_dir.parent / file_name + yaml.dump(update_values, project_config_path) + + sub_project_config_path = script_dir / file_name + yaml.dump(additional_update_values, sub_project_config_path) + + loaded_config = load_config(script_dir) + + assert loaded_config.model_dump() == Config.model_validate(expected_values).model_dump() + assert loaded_config._source_files == [ + home_config_path, + project_config_path, + sub_project_config_path, + ] + + minimal_lookup_config = load_config(script_dir, include_home_dir=False, lookup_depth=1) + + assert ( + minimal_lookup_config.model_dump() + == Config.model_validate(additional_update_values).model_dump() + ) + assert minimal_lookup_config._source_files == [sub_project_config_path] + + +@pytest.fixture +def import_load_script(tmp_path: Path): + """Copy import load script into tmp_path.""" + src_path = TEST_DATA / "config/run_load_config_on_import.py" + dest_path = tmp_path / "run_load_config_on_import.py" + copyfile(src_path, dest_path) + return dest_path + + +def test_load_config_on_import_no_local_config(tmp_path: Path, import_load_script: Path): + """No local config found.""" + subprocess.run([sys.executable, import_load_script]) + + assert (tmp_path / "source_files.json").is_file() is True + assert (tmp_path / "plotting.json").is_file() is True + + source_files = json.loads((tmp_path / "source_files.json").read_text()) + assert ( + any(tmp_path in Path(source_file).parents for source_file in source_files) is False + ), f"{tmp_path=}, {source_files=}" + + +def check_config(tmp_path: Path, plot_config_path: Path): + """Only testing specific aspects makes this test resilient against config pollution. + + Since the config uses a locality hierarchy even if a user has a user config it will be + overridden. + """ + assert (tmp_path / "source_files.json").is_file() is True + assert (tmp_path / "plotting.json").is_file() is True + + source_files = json.loads((tmp_path / "source_files.json").read_text()) + assert ( + any(plot_config_path == Path(source_file) for source_file in source_files) is True + ), f"{tmp_path=}, {source_files=}" + + plotting_dict = json.loads((tmp_path / "plotting.json").read_text()) + + assert ( + plotting_dict["general"]["default_args_override"]["will_update_arg"] == "will change arg" + ) + assert plotting_dict["general"]["default_args_override"]["will_be_kept_arg"] == "general arg" + assert ( + plotting_dict["general"]["axis_label_override"]["will_update_label"] == "will change label" + ) + assert plotting_dict["general"]["axis_label_override"]["will_be_kept_label"] == "general label" + + assert ( + plotting_dict["test_func"]["default_args_override"]["will_update_arg"] == "test_func arg" + ) + assert plotting_dict["test_func"]["default_args_override"]["will_be_added_arg"] == "new arg" + assert ( + plotting_dict["test_func"]["axis_label_override"]["will_update_label"] == "test_func label" + ) + assert plotting_dict["test_func"]["axis_label_override"]["will_be_added_label"] == "new label" + + +def test_load_config_on_import_local_config( + tmp_path: Path, import_load_script: Path, test_config_file: Path +): + """Check that config is properly loaded at import.""" + dest_path = tmp_path / f"{CONFIG_FILE_STEM}.yml" + copyfile(test_config_file, dest_path) + + subprocess.run([sys.executable, import_load_script]) + + check_config(tmp_path, dest_path) + + +def test_load_config_on_import_none_root_import( + tmp_path: Path, import_load_script: Path, test_config_file: Path +): + """Also works something from non root ``pyglotaran_extras`` was imported first.""" + dest_path = tmp_path / f"{CONFIG_FILE_STEM}.yml" + copyfile(test_config_file, dest_path) + + import_load_script.write_text( + "from pyglotaran_extras.inspect.a_matrix import show_a_matrixes\n" + f"{import_load_script.read_text()}" + ) + + subprocess.run([sys.executable, import_load_script]) + + check_config(tmp_path, dest_path) + + +def test_load_config_on_import_broken_config(tmp_path: Path, import_load_script: Path): + """Broken config does not brake script.""" + src_path = TEST_DATA / "config/broken_config.yml" + dest_path = tmp_path / f"{CONFIG_FILE_STEM}.yml" + copyfile(src_path, dest_path) + + res = subprocess.run([sys.executable, import_load_script], text=True, capture_output=True) + assert res.returncode == 0 + + expected_errors = [ + "plotting.general.axis_label_override.target_name", + "plotting.general.axis_label_override.not_allowed_general", + "plotting.general.invalid_kw_general", + "plotting.test_func.axis_label_override.target_name", + "plotting.test_func.axis_label_override.not_allowed_test_func", + "plotting.test_func.invalid_kw_test_func", + "invalid_kw_root", + ] + + assert f"{len(expected_errors)} validation errors for Config" in res.stderr + + assert f"Source path: {dest_path.as_posix()}" in res.stderr + + for error in expected_errors: + assert error in res.stderr, f"Expected error '{error}' not found in stderr" + + +@pytest.mark.usefixtures("mock_config") +def test_create_config_schema(tmp_path: Path, test_config_values: dict[str, Any]): + """A valid config doesn't cause any schema validation errors.""" + + @use_plot_config() + def other( + will_be_kept_arg="default update", + ): + pass + + @use_plot_config() + def test_func( + will_update_arg="default update", + will_be_kept_arg="default keep", + will_be_added_arg="default add", + ): + pass + + json_schema = json.loads(create_config_schema(tmp_path).read_text()) + expected_schema = json.loads( + (TEST_DATA / f"config/{CONFIG_FILE_STEM}.schema.json").read_text() + ) + + assert json_schema == expected_schema + + validator = Draft202012Validator(json_schema) + + assert generator_is_exhausted(validator.iter_errors(test_config_values)) + + +@pytest.mark.usefixtures("mock_config") +def test_create_config_schema_errors(tmp_path: Path): + """A broken config does cause schema validation errors.""" + + @use_plot_config() + def test_func( + will_update_arg="default update", + # we don't define will_be_added_arg and will_be_kept_arg because we want them to cause + # errors + ): + pass + + json_schema = json.loads(create_config_schema(tmp_path).read_text()) + expected_schema = json.loads((TEST_DATA / "config/broken_config.schema.json").read_text()) + + assert json_schema == expected_schema + + validator = Draft202012Validator(json_schema) + + broken_config_dict = YAML().load(TEST_DATA / "config/broken_config.yml") + expected_errors = [ + "Additional properties are not allowed ('invalid_kw_root' was unexpected)", + "Additional properties are not allowed ('invalid_kw_general' was unexpected)", + "Additional properties are not allowed ('will_be_kept_arg' was unexpected)", + "{'will_update_label': 'will change label', 'will_be_kept_label': {'not_allowed_general': " + "False}} is not valid under any of the given schemas", + "Additional properties are not allowed ('invalid_kw_test_func' was unexpected)", + "Additional properties are not allowed ('will_be_added_arg' was unexpected)", + "{'not_allowed_test_func': False} is not valid under any of the given schemas", + ] + for error, expected_error in zip( + validator.iter_errors(broken_config_dict), expected_errors, strict=True + ): + assert error.message.splitlines()[0] == expected_error, [ + error.message.splitlines()[0] for error in validator.iter_errors(broken_config_dict) + ] diff --git a/tests/config/test_plot_config.py b/tests/config/test_plot_config.py new file mode 100644 index 00000000..d74c94b2 --- /dev/null +++ b/tests/config/test_plot_config.py @@ -0,0 +1,683 @@ +"""Tests for ``pyglotaran_extras.config.plot_config``.""" + +from __future__ import annotations + +from functools import wraps +from io import StringIO +from textwrap import dedent +from typing import TYPE_CHECKING +from typing import Any + +import matplotlib.pyplot as plt +import pytest +from jsonschema import ValidationError as SchemaValidationError +from jsonschema import validate +from pydantic import ValidationError as PydanticValidationError +from ruamel.yaml import YAML + +from pyglotaran_extras.config.plot_config import PerFunctionPlotConfig +from pyglotaran_extras.config.plot_config import PlotConfig +from pyglotaran_extras.config.plot_config import PlotLabelOverrideMap +from pyglotaran_extras.config.plot_config import PlotLabelOverrideValue +from pyglotaran_extras.config.plot_config import extract_default_kwargs +from pyglotaran_extras.config.plot_config import find_axes +from pyglotaran_extras.config.plot_config import find_not_user_provided_kwargs +from pyglotaran_extras.config.plot_config import plot_config_context +from pyglotaran_extras.config.plot_config import use_plot_config +from tests import TEST_DATA +from tests.conftest import generator_is_exhausted + +if TYPE_CHECKING: + from typing import Literal + + from matplotlib.axes import Axes + + from pyglotaran_extras.config.config import Config + + +def test_plot_label_override_value_serialization(): + """Short notation is used if axis has default value.""" + assert PlotLabelOverrideValue(target_name="New Label").model_dump() == "New Label" + assert PlotLabelOverrideValue(target_name="New Label", axis="x").model_dump() == { + "target_name": "New Label", + "axis": "x", + } + + +def test_plot_label_override_map(): + """PlotLabelOverrideMap behaves like a mapping and the schema allows short notation.""" + axis_label_override: dict[str, Any] = YAML().load( + StringIO( + dedent( + """ + Old Label: New Label + Old Y Label: + target_name: New Label + axis: y + """ + ) + ) + ) + override_map = PlotLabelOverrideMap(axis_label_override) + + assert len(override_map) == 2 + + assert override_map["Old Label"] == PlotLabelOverrideValue(target_name="New Label") + assert override_map["Old Y Label"] == PlotLabelOverrideValue(target_name="New Label", axis="y") + + override_map_pydantic_init = PlotLabelOverrideMap( + {"Old Label": PlotLabelOverrideValue(target_name="New Label")} + ) + assert override_map_pydantic_init["Old Label"] == PlotLabelOverrideValue( + target_name="New Label" + ) + + validate(instance=axis_label_override, schema=PlotLabelOverrideMap.model_json_schema()) + + for map_item_tuple, expected in zip( + override_map.items(), axis_label_override.items(), strict=True + ): + assert (map_item_tuple[0], map_item_tuple[1].model_dump()) == expected + + with pytest.raises(SchemaValidationError) as execinfo: + validate( + instance={"Old Y Label": {"axis": "y"}}, + schema=PlotLabelOverrideMap.model_json_schema(), + ) + + assert str(execinfo.value).startswith("'target_name' is a required property") + + assert PlotLabelOverrideMap().model_dump() == {} + with pytest.raises(PydanticValidationError) as execinfo: + PlotLabelOverrideMap.model_validate({"invalid": {"invalid": 1}}) + + assert ( + "target_name\n Field required [type=missing, input_value={'invalid': 1}, input_type=dict]" + in str(execinfo.value) + ) + + +@pytest.mark.parametrize( + ("matplotlib_label", "axis_name", "expected"), + [ + ("not_found", "x", None), + ("not_found", "y", None), + ("no_user_newline", "x", "no_user_newline value"), + ("no\n_user_newline", "x", "no_user_newline value"), + ("with_\nuser_\nnewline", "x", "with_user_newline value"), + ("with_user_newline", "x", "with_user_newline value"), + ("x_only", "x", "x_only value"), + ("x_only", "y", None), + ("y_only", "x", None), + ("y_only", "y", "y_only value"), + ], +) +def test_plot_label_override_map_find_axis_label( + matplotlib_label: str, axis_name: Literal["x", "y"], expected: str | None +): + """Finding the correct label is agnostic to newlines injected by matplotlib.""" + override_map = PlotLabelOverrideMap.model_validate( + { + "no_user_newline": "no_user_newline value", + "with_\nuser_\nnewline": "with_user_newline value", + "x_only": {"target_name": "x_only value", "axis": "x"}, + "y_only": {"target_name": "y_only value", "axis": "y"}, + } + ) + assert override_map.find_axis_label(matplotlib_label, axis_name) == expected + + +def test_per_function_plot_config(): + """Initialize with correct defaults and validate correctly.""" + function_config_data: dict[str, Any] = YAML().load( + StringIO( + dedent( + """ + default_args_override: + test_arg: true + axis_label_override: + "Old Label": "New Label" + """ + ) + ) + ) + function_config = PerFunctionPlotConfig.model_validate(function_config_data) + + validate(instance=function_config_data, schema=PerFunctionPlotConfig.model_json_schema()) + + assert function_config.default_args_override["test_arg"] is True + assert function_config.axis_label_override["Old Label"] == PlotLabelOverrideValue( + target_name="New Label" + ) + + with pytest.raises(SchemaValidationError) as execinfo: + validate( + instance={"unknown": 1}, + schema=PerFunctionPlotConfig.model_json_schema(), + ) + + assert str(execinfo.value).startswith( + "Additional properties are not allowed ('unknown' was unexpected)" + ) + + assert PerFunctionPlotConfig().model_dump() == {} + + with pytest.raises(PydanticValidationError) as execinfo: + PerFunctionPlotConfig.model_validate({"unknown": 1}) + + assert ( + "1 validation error for PerFunctionPlotConfig\n" + "unknown\n" + " Extra inputs are not permitted [type=extra_forbidden, input_value=1, input_type=int]" + in str(execinfo.value) + ) + + +def test_per_function_plot_config_merge(): + """Values with same key get updated and other values stay the same.""" + original_config = PerFunctionPlotConfig( + default_args_override={"test_arg": "to_be_changed", "not_updated": "same"}, + axis_label_override={"Old Label": "to_be_changed", "not_updated": "same"}, + ) + update_config = PerFunctionPlotConfig( + default_args_override={"test_arg": "changed"}, + axis_label_override={"Old Label": "changed"}, + ) + + merged_config = original_config.merge(update_config) + + assert merged_config.default_args_override["test_arg"] == "changed" + assert merged_config.default_args_override["not_updated"] == "same" + assert merged_config.axis_label_override["Old Label"] == PlotLabelOverrideValue( + target_name="changed" + ) + assert merged_config.axis_label_override["not_updated"] == PlotLabelOverrideValue( + target_name="same" + ) + + +def test_per_function_plot_find_override_kwargs(): + """Only get kwargs that were not provided by the user and are known.""" + original_config = PerFunctionPlotConfig( + default_args_override={"test_arg": "to_be_changed", "not_updated": "same"}, + ) + + no_override_kwargs = original_config.find_override_kwargs(set()) + assert no_override_kwargs == {} + + override_kwargs = original_config.find_override_kwargs({"not_updated", "unknown_arg"}) + assert override_kwargs == {"not_updated": "same"} + + +def test_per_function_plot_update_axes_labels(): + """Only labels where the axis and current label match get updated.""" + + def create_test_ax() -> Axes: + _, ax = plt.subplots() + ax.set_xlabel("x") + ax.set_ylabel("y") + return ax + + simple_config = PerFunctionPlotConfig(axis_label_override={"x": "new x", "y": "new y"}) + + ax_both = create_test_ax() + simple_config.update_axes_labels(ax_both) + + assert ax_both.get_xlabel() == "new x" + assert ax_both.get_ylabel() == "new y" + + ax_explicit = create_test_ax() + + PerFunctionPlotConfig( + axis_label_override=PlotLabelOverrideMap( + { + "x": PlotLabelOverrideValue(target_name="new x", axis="x"), + "y": PlotLabelOverrideValue(target_name="new y", axis="y"), + } + ) + ).update_axes_labels(ax_explicit) + assert ax_explicit.get_xlabel() == "new x" + assert ax_explicit.get_ylabel() == "new y" + + ax_mismatch = create_test_ax() + + PerFunctionPlotConfig( + axis_label_override=PlotLabelOverrideMap( + { + "x": PlotLabelOverrideValue(target_name="new x", axis="y"), + "y": PlotLabelOverrideValue(target_name="new y", axis="x"), + } + ) + ).update_axes_labels(ax_mismatch) + assert ax_mismatch.get_xlabel() == "x" + assert ax_mismatch.get_ylabel() == "y" + + _, np_axes = plt.subplots(1, 2) + + assert np_axes.shape == (2,) + + np_axes[0].set_xlabel("x") + np_axes[0].set_ylabel("y-keep") + np_axes[1].set_xlabel("x-keep") + np_axes[1].set_ylabel("y") + + simple_config.update_axes_labels(np_axes) + + assert np_axes[0].get_xlabel() == "new x" + assert np_axes[0].get_ylabel() == "y-keep" + assert np_axes[1].get_xlabel() == "x-keep" + assert np_axes[1].get_ylabel() == "new y" + + _, ax0 = plt.subplots(1, 1) + _, ax1 = plt.subplots(1, 1) + + iterable_axes = (ax0, ax1) + + iterable_axes[0].set_xlabel("x") + iterable_axes[0].set_ylabel("y-keep") + iterable_axes[1].set_xlabel("x-keep") + iterable_axes[1].set_ylabel("y") + + simple_config.update_axes_labels(iterable_axes) + + assert iterable_axes[0].get_xlabel() == "new x" + assert iterable_axes[0].get_ylabel() == "y-keep" + assert iterable_axes[1].get_xlabel() == "x-keep" + assert iterable_axes[1].get_ylabel() == "new y" + + +def test_plot_config(): + """Initialize with correct defaults and validate correctly.""" + plot_config_values: dict[str, Any] = YAML().load( + StringIO( + dedent( + """ + general: + default_args_override: + test_arg: true + axis_label_override: + "Old Label": "New Label" + + test_func: + default_args_override: + test_arg: false + axis_label_override: + "Old Y Label": + target_name: "New Y Label" + axis: y + """ + ) + ) + ) + plot_config = PlotConfig.model_validate(plot_config_values) + + assert plot_config.model_extra is not None + + assert plot_config.general == PerFunctionPlotConfig( + default_args_override={"test_arg": True}, + axis_label_override={"Old Label": "New Label"}, + ) + assert plot_config.model_extra["test_func"] == PerFunctionPlotConfig( + default_args_override={"test_arg": False}, + axis_label_override={"Old Y Label": {"target_name": "New Y Label", "axis": "y"}}, + ) + + with pytest.raises(PydanticValidationError) as execinfo: + PerFunctionPlotConfig.model_validate({"test_func": {"unknown": 1}}) + + assert ( + "1 validation error for PerFunctionPlotConfig\n" + "test_func\n" + " Extra inputs are not permitted " + "[type=extra_forbidden, input_value={'unknown': 1}, input_type=dict]" + in str(execinfo.value) + ) + + +def test_plot_config_merge(): + """Merging creates the expected output. + + - Update fields that are present in original and update + - Keep fields that are not present in the update + - Add field that is only present in update + """ + original_values, update_values, expected_values = tuple( + YAML().load_all(StringIO((TEST_DATA / "config/plot_config_merge.yml").read_text())) + ) + original = PlotConfig.model_validate(original_values) + update = PlotConfig.model_validate(update_values) + expected = PlotConfig.model_validate(expected_values) + + assert original.merge(update) == expected + + +def test_plot_config_get_function_config(test_config_values: dict[str, Any]): + """The generated config updates the general config with the test func config. + + - Update fields that are present in general and test_fun config + - Keep fields that are not present in the test_func config + - Add field that is only present in test_func config + """ + plot_config = PlotConfig.model_validate(test_config_values["plotting"]) + + assert plot_config.get_function_config("test_func") == PerFunctionPlotConfig( + default_args_override={ + "will_update_arg": "test_func arg", + "will_be_kept_arg": "general arg", + "will_be_added_arg": "new arg", + }, + axis_label_override={ + "will_update_label": "test_func label", + "will_be_kept_label": "general label", + "will_be_added_label": "new label", + }, + ) + + +def test_extract_default_kwargs(): + """Extract argument that can be passed as kwargs.""" + + def func( + pos_arg: str, + pos_arg_default: str = "pos_arg_default", + /, + normal_arg: str = "normal_arg", + *, + kw_only_arg: int = 1, + ): + r"""Test function. + + Parameters + ---------- + pos_arg : str + Not extracted + pos_arg_default : str + Not extracted. Defaults to "pos_arg_default" + normal_arg : str + A normal arg. Defaults to "normal_arg". + kw_only_arg : int + A keyword only arg with new line and escaped character in the docstring (\\nu). + Defaults to 1. + """ + + assert extract_default_kwargs(func, ()) == { + "normal_arg": { + "default": "normal_arg", + "annotation": "str", + "docstring": 'A normal arg. Defaults to "normal_arg".', + }, + "kw_only_arg": { + "default": 1, + "annotation": "int", + "docstring": ( + r"A keyword only arg with new line and escaped character in the docstring (\\nu)." + " Defaults to 1." + ), + }, + } + + assert extract_default_kwargs(func, ("kw_only_arg",)) == { + "normal_arg": { + "default": "normal_arg", + "annotation": "str", + "docstring": 'A normal arg. Defaults to "normal_arg".', + }, + } + + def no_annotation(foo="bar"): + pass + + assert extract_default_kwargs(no_annotation, ()) == { + "foo": {"default": "bar", "annotation": "object", "docstring": None} + } + + +def test_find_not_user_provided_kwargs(): + """Only find kwarg names for none user passed kwargs.""" + result = None + + def dec(func): + default_kwargs = extract_default_kwargs(func, ()) + + @wraps(func) + def wrapper(*args, **kwargs): + nonlocal result + arg_names = func.__code__.co_varnames[: len(args)] + result = find_not_user_provided_kwargs(default_kwargs, arg_names, kwargs) + return func(*args, *kwargs) + + return wrapper + + @dec + def func( + pos_arg: str, + arg1: str = "arg1", + arg2: str = "arg2", + *, + kwarg1: int = 1, + kwarg2: int = 1, + ): + return 1 + + assert func("foo", "bar", kwarg2=2) == 1 + assert result == {"arg2", "kwarg1"} + + +def test_find_axes(): + """Get axes value from iterable of values.""" + + base_values = ["foo", True, 1.5] + + assert generator_is_exhausted(find_axes(base_values)) is True + + _, ax = plt.subplots() + single_ax_gen = find_axes([*base_values, ax]) + + assert next(single_ax_gen) is ax + assert generator_is_exhausted(single_ax_gen) is True + + _, np_axes = plt.subplots(1, 2) + + assert np_axes.shape == (2,) + + np_axes_gen = find_axes([*base_values, np_axes]) + + assert next(np_axes_gen) is np_axes[0] + assert next(np_axes_gen) is np_axes[1] + assert generator_is_exhausted(np_axes_gen) is True + + _, ax1 = plt.subplots() + iterable_axes = (ax, ax1) + iterable_axes_gen = find_axes([*base_values, iterable_axes]) + + assert next(iterable_axes_gen) is ax + assert next(iterable_axes_gen) is ax1 + assert generator_is_exhausted(iterable_axes_gen) is True + + multiple_axes_gen = find_axes([*base_values, ax, ax1]) + + assert next(multiple_axes_gen) is ax + assert next(multiple_axes_gen) is ax1 + assert generator_is_exhausted(multiple_axes_gen) is True + + +def test_use_plot_config(mock_config: tuple[Config, dict[str, Any]]): + """Config is applied to functions with the ``use_plot_config`` decorator.""" + _, registry = mock_config + + assert registry == {} + + @use_plot_config() + def test_func( + will_update_arg="default update", + will_be_kept_arg="default keep", + will_be_added_arg="default add", + not_in_config="not_in_config", + ): + kwargs = { + "will_update_arg": will_update_arg, + "will_be_kept_arg": will_be_kept_arg, + "will_be_added_arg": will_be_added_arg, + "not_in_config": not_in_config, + } + fig, (ax1, ax2) = plt.subplots(1, 2) + ax1.set_xlabel("will_update_label") + ax1.set_ylabel("will_be_kept_label") + ax2.set_xlabel("will_be_added_label") + ax2.set_ylabel("default") + return fig, (ax1, ax2), kwargs + + assert "test_func" in registry + + _, (ax1_test_func, ax2_test_func), kwargs_test_func_no_user_args = test_func() + + assert ax1_test_func.get_xlabel() == "test_func label" + assert ax1_test_func.get_ylabel() == "general label" + assert ax2_test_func.get_xlabel() == "new label" + assert ax2_test_func.get_ylabel() == "default" + + assert kwargs_test_func_no_user_args["will_update_arg"] == "test_func arg" + assert kwargs_test_func_no_user_args["will_be_kept_arg"] == "general arg" + assert kwargs_test_func_no_user_args["will_be_added_arg"] == "new arg" + assert kwargs_test_func_no_user_args["not_in_config"] == "not_in_config" + + _, _, kwargs_test_func_user_args = test_func( + will_update_arg="set by user", will_be_added_arg="added by user" + ) + + assert kwargs_test_func_user_args["will_update_arg"] == "set by user" + assert kwargs_test_func_user_args["will_be_kept_arg"] == "general arg" + assert kwargs_test_func_user_args["will_be_added_arg"] == "added by user" + assert kwargs_test_func_no_user_args["not_in_config"] == "not_in_config" + + _, axes = plt.subplots(1, 2) + + @use_plot_config() + def axes_iterable_arg( + axes: tuple[Axes, Axes], + ): + (ax1, ax2) = axes + ax1.set_xlabel("will_update_label") + ax1.set_ylabel("will_be_kept_label") + ax2.set_xlabel("will_be_added_label") + ax2.set_ylabel("default") + return (ax1, ax2) + + assert "axes_iterable_arg" in registry + + axes_iterable_arg((axes[0], axes[1])) + + assert axes[0].get_xlabel() == "will change label" + assert axes[0].get_ylabel() == "general label" + assert axes[1].get_xlabel() == "will_be_added_label" + assert axes[1].get_ylabel() == "default" + + _, (ax1_arg, ax2_arg) = plt.subplots(1, 2) + + @use_plot_config() + def multiple_axes_args( + ax1: Axes, + ax2: Axes, + ): + ax1.set_xlabel("will_update_label") + ax1.set_ylabel("will_be_kept_label") + ax2.set_xlabel("will_be_added_label") + ax2.set_ylabel("default") + return (ax1, ax2) + + assert "multiple_axes_args" in registry + + multiple_axes_args(ax1_arg, ax2_arg) + + assert ax1_arg.get_xlabel() == "will change label" + assert ax1_arg.get_ylabel() == "general label" + assert ax2_arg.get_xlabel() == "will_be_added_label" + assert ax2_arg.get_ylabel() == "default" + + # Integration test that ``PlotLabelOverrideMap.find_axis_label`` is used + _, ax_newline = plt.subplots() + + @use_plot_config() + def newline_label( + ax: Axes, + ): + ax.set_xlabel("will_\nupdate_label") + ax.set_ylabel("will_be_\nkept_label") + return ax + + assert "newline_label" in registry + + newline_label(ax_newline) + + assert ax_newline.get_xlabel() == "will change label" + assert ax_newline.get_ylabel() == "general label" + + +@pytest.mark.usefixtures("mock_config") +def test_plot_config_context(): + """Context overrides resolved config values of the function.""" + import pyglotaran_extras + + source_file = pyglotaran_extras.CONFIG._source_files[0] + original_test_func_config = pyglotaran_extras.CONFIG.plotting.get_function_config( + "test_func" + ).model_copy(deep=True) + + plot_config = PerFunctionPlotConfig( + default_args_override={ + "will_be_added_arg": "test_func arg overridden by context arg", + "added_by_context_arg": "added by context arg", + }, + axis_label_override={ + "will_be_added_label": "test_func arg overridden by context label", + "added_by_context_label": "added by context label", + }, + ) + + _, (ax1_arg, ax2_arg) = plt.subplots(1, 2) + + @use_plot_config() + def test_func( + ax1: Axes, + ax2: Axes, + ): + ax1.set_xlabel("will_update_label") + ax1.set_ylabel("will_be_kept_label") + ax2.set_xlabel("will_be_added_label") + ax2.set_ylabel("default") + return (ax1, ax2) + + with plot_config_context(plot_config): + test_func_config = pyglotaran_extras.CONFIG.plotting.get_function_config("test_func") + + # Force reload in use_plot_config by changing mtime of the file + source_file.write_bytes(source_file.read_bytes()) + + test_func(ax1_arg, ax2_arg) + + assert hasattr(pyglotaran_extras.CONFIG.plotting, "__context_config") is False + + assert test_func_config == PerFunctionPlotConfig( + default_args_override={ + "will_update_arg": "test_func arg", + "will_be_kept_arg": "general arg", + "will_be_added_arg": "test_func arg overridden by context arg", + "added_by_context_arg": "added by context arg", + }, + axis_label_override={ + "will_update_label": "test_func label", + "will_be_kept_label": "general label", + "will_be_added_label": "test_func arg overridden by context label", + "added_by_context_label": "added by context label", + }, + ) + + assert ax1_arg.get_xlabel() == "test_func label" + assert ax1_arg.get_ylabel() == "general label" + assert ax2_arg.get_xlabel() == "test_func arg overridden by context label" + assert ax2_arg.get_ylabel() == "default" + + assert ( + pyglotaran_extras.CONFIG.plotting.get_function_config("test_func") + == original_test_func_config + ) diff --git a/tests/config/test_utils.py b/tests/config/test_utils.py new file mode 100644 index 00000000..48e89afc --- /dev/null +++ b/tests/config/test_utils.py @@ -0,0 +1,54 @@ +"""Tests for ``pyglotaran_extras.config.utils``.""" + +from __future__ import annotations + +from textwrap import dedent +from typing import Any + +from IPython.core.formatters import format_display_data +from pydantic import BaseModel + +from pyglotaran_extras.config.utils import add_yaml_repr +from pyglotaran_extras.config.utils import to_yaml_str + + +class UtilTestClass(BaseModel): + """Class with test data by default.""" + + str_attr: str = "str_val" + int_attr: int = 2 + dict_attr: dict[str, Any] = {"key1": 1, "key2": None} + list_attr: list[Any] = ["str", 1, None] + + +EXPECTED_YAML_STR = dedent( + """\ + str_attr: str_val + int_attr: 2 + dict_attr: + key1: 1 + key2: null + list_attr: + - str + - 1 + - null + """ +) + + +def test_to_yaml_str(): + """Created yaml str has expected format.""" + test_instance = UtilTestClass() + assert to_yaml_str(test_instance) == EXPECTED_YAML_STR + + +def test_add_yaml_repr(): + """Added methods behave as expected when converting to string and rendering in ipython.""" + test_instance = add_yaml_repr(UtilTestClass)() + assert str(test_instance) == EXPECTED_YAML_STR + + rendered_result = format_display_data(test_instance)[0] + + assert "text/markdown" in rendered_result + assert rendered_result["text/markdown"] == f"```yaml\n{EXPECTED_YAML_STR}\n```" + assert rendered_result["text/plain"] == repr(test_instance) diff --git a/tests/conftest.py b/tests/conftest.py index 18769cd3..5639a412 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,13 @@ from __future__ import annotations +import sys +from contextlib import contextmanager +from pathlib import Path +from shutil import copyfile +from typing import TYPE_CHECKING +from typing import Any + # isort: off # Hack around https://github.com/pydata/xarray/issues/7259 which also affects pyglotaran <= 0.7.0 import numpy # noqa: F401 @@ -10,19 +17,59 @@ from dataclasses import replace +import matplotlib.pyplot as plt import pytest from glotaran.optimization.optimize import optimize from glotaran.testing.simulated_data.parallel_spectral_decay import SCHEME as SCHEME_PAR from glotaran.testing.simulated_data.sequential_spectral_decay import SCHEME as SCHEME_SEQ +from ruamel.yaml import YAML +from pyglotaran_extras.config.config import CONFIG_FILE_STEM +from pyglotaran_extras.config.config import Config from pyglotaran_extras.io.setup_case_study import get_script_dir +from tests import TEST_DATA + +if TYPE_CHECKING: + from collections.abc import Generator + + +@contextmanager +def monkeypatch_all(monkeypatch: pytest.MonkeyPatch, name: str, value: Any): + """Context to monkeypatch all usages across modules.""" + with monkeypatch.context() as m: + for module_name, module in sys.modules.items(): + if module_name.startswith("pyglotaran_extras") and hasattr(module, name): + m.setattr(module, name, value) + yield + +def generator_is_exhausted(generator: Generator) -> bool: + """Check if ``generator`` is exhausted. -def wrapped_get_script_dir(): + Parameters + ---------- + generator : Generator + Generator to check. + + Returns + ------- + bool + """ + is_empty = object() + return next(generator, is_empty) is is_empty + + +def wrapped_get_script_dir() -> Path: """Test function for calls to get_script_dir used inside of other functions.""" return get_script_dir(nesting=1) +@pytest.fixture(autouse=True) +def _close_matplotlib_figures(): + """Close all figures after each test function.""" + plt.close() + + @pytest.fixture(scope="session") def result_parallel_spectral_decay(): """Test result from ``glotaran.testing.simulated_data.parallel_spectral_decay``.""" @@ -35,3 +82,49 @@ def result_sequential_spectral_decay(): """Test result from ``glotaran.testing.simulated_data.sequential_spectral_decay``.""" scheme = replace(SCHEME_SEQ, maximum_number_function_evaluations=1) return optimize(scheme) + + +@pytest.fixture +def mock_home(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): + """Mock ``pathlib.Path.home`` to return ``tmp_path/"home"``.""" + + mock_home_path = tmp_path / "home" + mock_home_path.mkdir() + + with monkeypatch.context() as m: + m.setattr(Path, "home", lambda: mock_home_path) + MockPath = Path # noqa: N806 + with monkeypatch_all(monkeypatch, "Path", MockPath): + yield mock_home_path + + +@pytest.fixture +def test_config_file() -> Path: + """Path to the test config file.""" + return TEST_DATA / "config/pygta_config.yml" + + +@pytest.fixture +def test_config_values(test_config_file: Path) -> dict[str, Any]: + """Read test config into dict.""" + return YAML().load(test_config_file) + + +@pytest.fixture +def mock_config( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path, test_config_file: Path +) -> Generator[tuple[Config, dict[str, Any]], None, None]: + """Mock config with test config and empty function registry.""" + dest_path = tmp_path / f"{CONFIG_FILE_STEM}.yml" + copyfile(test_config_file, dest_path) + config = Config() + config.load(dest_path) + mock_registry: dict[str, Any] = {} + project_dir = tmp_path / "project" + project_dir.mkdir(parents=True, exist_ok=True) + with ( + monkeypatch_all(monkeypatch, "CONFIG", config), + monkeypatch_all(monkeypatch, "SCRIPT_DIR", project_dir), + monkeypatch_all(monkeypatch, "__PlotFunctionRegistry", mock_registry), + ): + yield config, mock_registry diff --git a/tests/data/config/broken_config.schema.json b/tests/data/config/broken_config.schema.json new file mode 100644 index 00000000..3b7a107d --- /dev/null +++ b/tests/data/config/broken_config.schema.json @@ -0,0 +1,147 @@ +{ + "$defs": { + "PerFunctionPlotConfig": { + "additionalProperties": false, + "description": "Per function plot configuration.", + "properties": { + "default_args_override": { + "description": "Default arguments to use if not specified in function call.", + "title": "Default Args Override", + "type": "object", + "properties": { + "will_update_arg": { + "default": "default update", + "title": "Will Update Arg" + } + }, + "additionalProperties": false + }, + "axis_label_override": { + "anyOf": [ + { + "$ref": "#/$defs/PlotLabelOverrideMap" + }, + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + } + ], + "title": "Axis Label Override" + } + }, + "title": "PerFunctionPlotConfig", + "type": "object" + }, + "PlotConfig": { + "additionalProperties": true, + "description": "Config for plot functions including default args and label overrides.", + "properties": { + "general": { + "allOf": [ + { + "$ref": "#/$defs/PerFunctionPlotConfig" + } + ], + "description": "Config that gets applied to all functions if not specified otherwise." + }, + "test_func": { + "allOf": [ + { + "$ref": "#/$defs/TestFuncConfig" + } + ] + } + }, + "title": "PlotConfig", + "type": "object" + }, + "PlotLabelOverrideMap": { + "additionalProperties": { + "anyOf": [ + { + "$ref": "#/$defs/PlotLabelOverrideValue" + }, + { + "type": "string" + } + ] + }, + "description": "Mapping to override axis labels.", + "title": "PlotLabelOverrideMap", + "type": "object" + }, + "PlotLabelOverrideValue": { + "additionalProperties": false, + "description": "Value of ``PlotLabelOverrideMap``.", + "properties": { + "target_name": { + "title": "Target Name", + "type": "string" + }, + "axis": { + "default": "both", + "enum": ["x", "y", "both"], + "title": "Axis", + "type": "string" + } + }, + "required": ["target_name"], + "title": "PlotLabelOverrideValue", + "type": "object" + }, + "TestFuncKwargs": { + "additionalProperties": false, + "description": "Default arguments to use for ``test_func``, if not specified in function call.", + "properties": { + "will_update_arg": { + "default": "default update", + "title": "Will Update Arg" + } + }, + "title": "TestFuncKwargs", + "type": "object" + }, + "TestFuncConfig": { + "additionalProperties": false, + "description": "Plot function configuration specific to ``test_func`` (overrides values in general).", + "properties": { + "default_args_override": { + "allOf": [ + { + "$ref": "#/$defs/TestFuncKwargs" + } + ], + "default": {} + }, + "axis_label_override": { + "allOf": [ + { + "$ref": "#/$defs/PlotLabelOverrideMap" + } + ], + "default": {} + } + }, + "title": "TestFuncConfig", + "type": "object" + } + }, + "additionalProperties": false, + "description": "Main configuration class.", + "properties": { + "plotting": { + "allOf": [ + { + "$ref": "#/$defs/PlotConfig" + } + ], + "default": { + "general": {} + } + } + }, + "title": "Config", + "type": "object" +} diff --git a/tests/data/config/broken_config.yml b/tests/data/config/broken_config.yml new file mode 100644 index 00000000..6f9e0fcb --- /dev/null +++ b/tests/data/config/broken_config.yml @@ -0,0 +1,24 @@ +# yaml-language-server: $schema=broken_config.schema.json + +plotting: + general: + default_args_override: + will_update_arg: will change arg + will_be_kept_arg: general arg + axis_label_override: + will_update_label: will change label + will_be_kept_label: + not_allowed_general: False + invalid_kw_general: True + + test_func: + default_args_override: + will_update_arg: test_func arg + will_be_added_arg: new arg + axis_label_override: + will_update_label: test_func label + will_be_added_label: + not_allowed_test_func: False + invalid_kw_test_func: True + +invalid_kw_root: True diff --git a/tests/data/config/config_merge.yml b/tests/data/config/config_merge.yml new file mode 100644 index 00000000..6fa1f9b5 --- /dev/null +++ b/tests/data/config/config_merge.yml @@ -0,0 +1,74 @@ +--- +# Original +plotting: + general: + default_args_override: + will_update_arg: will change arg + will_be_kept_arg: original arg + axis_label_override: + will_update_label: will change label + will_be_kept_label: original label + + test_func: + default_args_override: + will_update_arg: will change arg extra + will_be_kept_arg: original arg extra + axis_label_override: + will_update_label: will change label extra + will_be_kept_label: original label extra + + only_in_original_extra: + default_args_override: + arg: only in original +--- +# Update +plotting: + general: + default_args_override: + will_update_arg: changed arg + will_be_added_arg: new arg + axis_label_override: + will_update_label: changed label + will_be_added_label: new label + + test_func: + default_args_override: + will_update_arg: changed arg extra + will_be_added_arg: new arg extra + axis_label_override: + will_update_label: changed label extra + will_be_added_label: new label extra + + only_in_update_extra: + default_args_override: + arg: only in update +--- +# Expected +plotting: + general: + default_args_override: + will_update_arg: changed arg + will_be_kept_arg: original arg + will_be_added_arg: new arg + axis_label_override: + will_update_label: changed label + will_be_kept_label: original label + will_be_added_label: new label + + test_func: + default_args_override: + will_update_arg: changed arg extra + will_be_kept_arg: original arg extra + will_be_added_arg: new arg extra + axis_label_override: + will_update_label: changed label extra + will_be_kept_label: original label extra + will_be_added_label: new label extra + + only_in_original_extra: + default_args_override: + arg: only in original + + only_in_update_extra: + default_args_override: + arg: only in update diff --git a/tests/data/config/plot_config_merge.yml b/tests/data/config/plot_config_merge.yml new file mode 100644 index 00000000..9606f1ed --- /dev/null +++ b/tests/data/config/plot_config_merge.yml @@ -0,0 +1,71 @@ +--- +# Original +general: + default_args_override: + will_update_arg: will change arg + will_be_kept_arg: original arg + axis_label_override: + will_update_label: will change label + will_be_kept_label: original label + +test_func: + default_args_override: + will_update_arg: will change arg extra + will_be_kept_arg: original arg extra + axis_label_override: + will_update_label: will change label extra + will_be_kept_label: original label extra + +only_in_original_extra: + default_args_override: + arg: only in original +--- +# Update +general: + default_args_override: + will_update_arg: changed arg + will_be_added_arg: new arg + axis_label_override: + will_update_label: changed label + will_be_added_label: new label + +test_func: + default_args_override: + will_update_arg: changed arg extra + will_be_added_arg: new arg extra + axis_label_override: + will_update_label: changed label extra + will_be_added_label: new label extra + +only_in_update_extra: + default_args_override: + arg: only in update +--- +# Expected +general: + default_args_override: + will_update_arg: changed arg + will_be_kept_arg: original arg + will_be_added_arg: new arg + axis_label_override: + will_update_label: changed label + will_be_kept_label: original label + will_be_added_label: new label + +test_func: + default_args_override: + will_update_arg: changed arg extra + will_be_kept_arg: original arg extra + will_be_added_arg: new arg extra + axis_label_override: + will_update_label: changed label extra + will_be_kept_label: original label extra + will_be_added_label: new label extra + +only_in_original_extra: + default_args_override: + arg: only in original + +only_in_update_extra: + default_args_override: + arg: only in update diff --git a/tests/data/config/pygta_config.schema.json b/tests/data/config/pygta_config.schema.json new file mode 100644 index 00000000..fee60dec --- /dev/null +++ b/tests/data/config/pygta_config.schema.json @@ -0,0 +1,206 @@ +{ + "$defs": { + "PerFunctionPlotConfig": { + "additionalProperties": false, + "description": "Per function plot configuration.", + "properties": { + "default_args_override": { + "description": "Default arguments to use if not specified in function call.", + "title": "Default Args Override", + "type": "object", + "properties": { + "will_be_kept_arg": { + "default": "default keep", + "title": "Will Be Kept Arg" + }, + "will_update_arg": { + "default": "default update", + "title": "Will Update Arg" + }, + "will_be_added_arg": { + "default": "default add", + "title": "Will Be Added Arg" + } + }, + "additionalProperties": false + }, + "axis_label_override": { + "anyOf": [ + { + "$ref": "#/$defs/PlotLabelOverrideMap" + }, + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + } + ], + "title": "Axis Label Override" + } + }, + "title": "PerFunctionPlotConfig", + "type": "object" + }, + "PlotConfig": { + "additionalProperties": true, + "description": "Config for plot functions including default args and label overrides.", + "properties": { + "general": { + "allOf": [ + { + "$ref": "#/$defs/PerFunctionPlotConfig" + } + ], + "description": "Config that gets applied to all functions if not specified otherwise." + }, + "other": { + "allOf": [ + { + "$ref": "#/$defs/OtherConfig" + } + ] + }, + "test_func": { + "allOf": [ + { + "$ref": "#/$defs/TestFuncConfig" + } + ] + } + }, + "title": "PlotConfig", + "type": "object" + }, + "PlotLabelOverrideMap": { + "additionalProperties": { + "anyOf": [ + { + "$ref": "#/$defs/PlotLabelOverrideValue" + }, + { + "type": "string" + } + ] + }, + "description": "Mapping to override axis labels.", + "title": "PlotLabelOverrideMap", + "type": "object" + }, + "PlotLabelOverrideValue": { + "additionalProperties": false, + "description": "Value of ``PlotLabelOverrideMap``.", + "properties": { + "target_name": { + "title": "Target Name", + "type": "string" + }, + "axis": { + "default": "both", + "enum": ["x", "y", "both"], + "title": "Axis", + "type": "string" + } + }, + "required": ["target_name"], + "title": "PlotLabelOverrideValue", + "type": "object" + }, + "OtherKwargs": { + "additionalProperties": false, + "description": "Default arguments to use for ``other``, if not specified in function call.", + "properties": { + "will_be_kept_arg": { + "default": "default update", + "title": "Will Be Kept Arg" + } + }, + "title": "OtherKwargs", + "type": "object" + }, + "OtherConfig": { + "additionalProperties": false, + "description": "Plot function configuration specific to ``other`` (overrides values in general).", + "properties": { + "default_args_override": { + "allOf": [ + { + "$ref": "#/$defs/OtherKwargs" + } + ], + "default": {} + }, + "axis_label_override": { + "allOf": [ + { + "$ref": "#/$defs/PlotLabelOverrideMap" + } + ], + "default": {} + } + }, + "title": "OtherConfig", + "type": "object" + }, + "TestFuncKwargs": { + "additionalProperties": false, + "description": "Default arguments to use for ``test_func``, if not specified in function call.", + "properties": { + "will_update_arg": { + "default": "default update", + "title": "Will Update Arg" + }, + "will_be_kept_arg": { + "default": "default keep", + "title": "Will Be Kept Arg" + }, + "will_be_added_arg": { + "default": "default add", + "title": "Will Be Added Arg" + } + }, + "title": "TestFuncKwargs", + "type": "object" + }, + "TestFuncConfig": { + "additionalProperties": false, + "description": "Plot function configuration specific to ``test_func`` (overrides values in general).", + "properties": { + "default_args_override": { + "allOf": [ + { + "$ref": "#/$defs/TestFuncKwargs" + } + ], + "default": {} + }, + "axis_label_override": { + "allOf": [ + { + "$ref": "#/$defs/PlotLabelOverrideMap" + } + ], + "default": {} + } + }, + "title": "TestFuncConfig", + "type": "object" + } + }, + "additionalProperties": false, + "description": "Main configuration class.", + "properties": { + "plotting": { + "allOf": [ + { + "$ref": "#/$defs/PlotConfig" + } + ], + "default": { + "general": {} + } + } + }, + "title": "Config", + "type": "object" +} diff --git a/tests/data/config/pygta_config.yml b/tests/data/config/pygta_config.yml new file mode 100644 index 00000000..7e7b9ca4 --- /dev/null +++ b/tests/data/config/pygta_config.yml @@ -0,0 +1,17 @@ +# yaml-language-server: $schema=pygta_config.schema.json + +plotting: + general: + default_args_override: + will_update_arg: will change arg + will_be_kept_arg: general arg + axis_label_override: + will_update_label: will change label + will_be_kept_label: general label + test_func: + default_args_override: + will_update_arg: test_func arg + will_be_added_arg: new arg + axis_label_override: + will_update_label: test_func label + will_be_added_label: new label diff --git a/tests/data/config/run_load_config_on_import.py b/tests/data/config/run_load_config_on_import.py new file mode 100644 index 00000000..63c4f49f --- /dev/null +++ b/tests/data/config/run_load_config_on_import.py @@ -0,0 +1,17 @@ +import json +from pathlib import Path + +# This is needed to simulate the case were something else is imported from ``pyglotaran_extras`` +# Before the config is +from pyglotaran_extras import plot_overview # noqa: F401 + +# Added before the import so imports don't get sorted +HERE = Path(__file__).parent + +# We just import the config directly to read out the values +from pyglotaran_extras import CONFIG # noqa: E402 + +(HERE / "source_files.json").write_text( + json.dumps([source_file.as_posix() for source_file in CONFIG._source_files]) +) +(HERE / "plotting.json").write_text(CONFIG.plotting.model_dump_json()) diff --git a/tests/deprecation/test_deprecation_utils.py b/tests/deprecation/test_deprecation_utils.py index 9878de93..41e8f481 100644 --- a/tests/deprecation/test_deprecation_utils.py +++ b/tests/deprecation/test_deprecation_utils.py @@ -8,7 +8,7 @@ import pytest import pyglotaran_extras -from pyglotaran_extras.deprecation.deprecation_utils import OverDueDeprecationError +from pyglotaran_extras.deprecation.deprecation_utils import OverdueDeprecationError from pyglotaran_extras.deprecation.deprecation_utils import PyglotaranExtrasApiDeprecationWarning from pyglotaran_extras.deprecation.deprecation_utils import check_overdue from pyglotaran_extras.deprecation.deprecation_utils import parse_version @@ -37,7 +37,7 @@ ) -@pytest.fixture() +@pytest.fixture def _pyglotaran_extras_0_3_0(monkeypatch: MonkeyPatch): """Mock pyglotaran_extras version to always be 0.3.0 for the test.""" monkeypatch.setattr( @@ -47,7 +47,7 @@ def _pyglotaran_extras_0_3_0(monkeypatch: MonkeyPatch): ) -@pytest.fixture() +@pytest.fixture def _pyglotaran_extras_1_0_0(monkeypatch: MonkeyPatch): """Mock pyglotaran_extras version to always be 1.0.0 for the test.""" monkeypatch.setattr( @@ -98,7 +98,7 @@ def test_check_overdue_no_raise(monkeypatch: MonkeyPatch): @pytest.mark.usefixtures("_pyglotaran_extras_1_0_0") def test_check_overdue_raises(monkeypatch: MonkeyPatch): """Current version is equal or bigger than drop_version.""" - with pytest.raises(OverDueDeprecationError) as excinfo: + with pytest.raises(OverdueDeprecationError) as excinfo: check_overdue( deprecated_qual_name_usage=DEPRECATION_QUAL_NAME, to_be_removed_in_version="0.6.0", @@ -126,7 +126,7 @@ def test_warn_deprecated(): def test_warn_deprecated_overdue_deprecation(monkeypatch: MonkeyPatch): """Current version is equal or bigger than drop_version.""" - with pytest.raises(OverDueDeprecationError) as excinfo: + with pytest.raises(OverdueDeprecationError) as excinfo: warn_deprecated( deprecated_qual_name_usage=DEPRECATION_QUAL_NAME, new_qual_name_usage=NEW_QUAL_NAME, @@ -145,7 +145,7 @@ def test_warn_deprecated_no_overdue_deprecation_on_dev(monkeypatch: MonkeyPatch) lambda: "0.6.0-dev", ) - with pytest.raises(OverDueDeprecationError): + with pytest.raises(OverdueDeprecationError): warn_deprecated( deprecated_qual_name_usage=DEPRECATION_QUAL_NAME, new_qual_name_usage=NEW_QUAL_NAME, diff --git a/tests/io/test_load_data.py b/tests/io/test_load_data.py index 2edc3679..8c98ae95 100644 --- a/tests/io/test_load_data.py +++ b/tests/io/test_load_data.py @@ -79,7 +79,7 @@ def test_load_data( filtered_warnings = filter_warnings(recwarn) assert len(filtered_warnings) == 1 - assert filtered_warnings[0].category == UserWarning + assert filtered_warnings[0].category is UserWarning assert filtered_warnings[0].message.args[0] == MULTI_DATASET_WARING # type:ignore[union-attr] assert Path(filtered_warnings[0].filename) == Path(__file__) @@ -94,7 +94,7 @@ def wrapped_call(result: Result): assert len(filtered_warnings) == 2 - assert filtered_warnings[1].category == UserWarning + assert filtered_warnings[1].category is UserWarning assert filtered_warnings[1].message.args[0] == MULTI_DATASET_WARING # type:ignore[union-attr] assert Path(filtered_warnings[1].filename) == Path(__file__) diff --git a/tests/plotting/test_utils.py b/tests/plotting/test_utils.py index f98ad447..b3e5d509 100644 --- a/tests/plotting/test_utils.py +++ b/tests/plotting/test_utils.py @@ -41,7 +41,7 @@ ) def test_add_cycler_if_not_none_single_axis(cycler: Cycler | None, expected_cycler: cycle): """Default cycler if None and cycler otherwise on a single axis.""" - ax = plt.subplot() + _, ax = plt.subplots() add_cycler_if_not_none(ax, cycler) for _ in range(10): @@ -197,8 +197,6 @@ def test_add_subplot_labels_assignment( assert [ax.texts[0].get_anncoords() for ax in axes.flatten()] == [label_coords] * 4 assert [ax.texts[0].get_fontsize() for ax in axes.flatten()] == [fontsize] * 4 - plt.close() - @pytest.mark.parametrize(("label_format_template", "expected"), [("{})", "1)"), ("({})", "(1)")]) def test_add_subplot_labels_label_format_template(label_format_template: str, expected: str):