diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b3ac808be..2478c3ee0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -77,8 +77,8 @@ repos: rev: 6.1.1 hooks: - id: pydocstyle - files: "^glotaran/(plugin_system|utils|deprecation)" - exclude: "docs|tests?" + files: "^glotaran/(plugin_system|utils|deprecation|testing)" + exclude: "docs|tests?/" # this is needed due to the following issue: # https://github.com/PyCQA/pydocstyle/issues/368 args: [--ignore-decorators=wrap_func_as_method] @@ -87,14 +87,14 @@ repos: rev: v1.8.0 hooks: - id: darglint - files: "^glotaran/(plugin_system|utils|deprecation)" - exclude: "docs|tests?" + files: "^glotaran/(plugin_system|utils|deprecation|testing)" + exclude: "docs|tests?/" - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.910 hooks: - id: mypy - files: "^glotaran/(plugin_system|utils|deprecation)" + files: "^glotaran/(plugin_system|utils|deprecation|testing)" exclude: "docs" additional_dependencies: [types-all] diff --git a/benchmark/pytest/analysis/test_problem.py b/benchmark/pytest/analysis/test_problem.py index 9c148818f..16aa886c3 100644 --- a/benchmark/pytest/analysis/test_problem.py +++ b/benchmark/pytest/analysis/test_problem.py @@ -13,6 +13,7 @@ from glotaran.model import megacomplex from glotaran.parameter import ParameterGroup from glotaran.project import Scheme +from glotaran.testing.plugin_system import monkeypatch_plugin_registry if TYPE_CHECKING: from glotaran.model import DatasetModel @@ -53,6 +54,7 @@ def finalize_data( pass +@monkeypatch_plugin_registry(test_megacomplex={"benchmark": BenchmarkMegacomplex}) def setup_model(index_dependent): model_dict = { "megacomplex": {"m1": {"is_index_dependent": index_dependent}}, diff --git a/glotaran/builtin/io/yml/yml.py b/glotaran/builtin/io/yml/yml.py index e45ad6cca..540c59824 100644 --- a/glotaran/builtin/io/yml/yml.py +++ b/glotaran/builtin/io/yml/yml.py @@ -16,7 +16,6 @@ from glotaran.io import save_dataset from glotaran.io import save_parameters from glotaran.model import Model -from glotaran.model import get_megacomplex from glotaran.parameter import ParameterGroup from glotaran.project import SavingOptions from glotaran.project import Scheme @@ -66,18 +65,7 @@ def load_model(self, file_name: str) -> Model: if "megacomplex" not in spec: raise ValueError("No megacomplex defined in model") - megacomplex_types = { - m["type"]: get_megacomplex(m["type"]) - for m in spec["megacomplex"].values() - if "type" in m - } - if default_megacomplex is not None: - megacomplex_types[default_megacomplex] = get_megacomplex(default_megacomplex) - del spec["default-megacomplex"] - - return Model.from_dict( - spec, megacomplex_types=megacomplex_types, default_megacomplex_type=default_megacomplex - ) + return Model.from_dict(spec, megacomplex_types=None, default_megacomplex_type=None) def load_parameters(self, file_name: str) -> ParameterGroup: diff --git a/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py index f788a6b90..939fb6c1b 100644 --- a/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py @@ -6,11 +6,10 @@ from glotaran.analysis.optimize import optimize from glotaran.analysis.simulation import simulate -from glotaran.builtin.megacomplexes.decay import DecayMegacomplex -from glotaran.model import Megacomplex from glotaran.model import Model from glotaran.parameter import ParameterGroup from glotaran.project import Scheme +from glotaran.testing.model_generators import SimpleModelGenerator def _create_gaussian_clp(labels, amplitudes, centers, widths, axis): @@ -28,20 +27,9 @@ class DecayModel(Model): def from_dict( cls, model_dict, - *, - megacomplex_types: dict[str, type[Megacomplex]] | None = None, - default_megacomplex_type: str | None = None, ): - defaults: dict[str, type[Megacomplex]] = { - "decay": DecayMegacomplex, - } - if megacomplex_types is not None: - defaults.update(megacomplex_types) - return super().from_dict( - model_dict, - megacomplex_types=defaults, - default_megacomplex_type=default_megacomplex_type, - ) + model_dict = {**model_dict, "default-megacomplex": "decay"} + return super().from_dict(model_dict) class OneComponentOneChannel: @@ -136,62 +124,16 @@ class OneComponentOneChannelGaussianIrf: class ThreeComponentParallel: - model = DecayModel.from_dict( - { - "initial_concentration": { - "j1": {"compartments": ["s1", "s2", "s3"], "parameters": ["j.1", "j.1", "j.1"]}, - }, - "megacomplex": { - "mc1": {"k_matrix": ["k1"]}, - }, - "k_matrix": { - "k1": { - "matrix": { - ("s2", "s1"): "kinetic.1", - ("s3", "s2"): "kinetic.2", - ("s3", "s3"): "kinetic.3", - } - } - }, - "irf": { - "irf1": { - "type": "multi-gaussian", - "center": ["irf.center"], - "width": ["irf.width"], - }, - }, - "dataset": { - "dataset1": { - "initial_concentration": "j1", - "irf": "irf1", - "megacomplex": ["mc1"], - }, - }, - } + generator = SimpleModelGenerator( + rates=[300e-3, 500e-4, 700e-5], + irf={"center": 1.3, "width": 7.8}, + k_matrix="parallel", ) + model, initial_parameters = generator.model_and_parameters + + generator.rates = [301e-3, 502e-4, 705e-5] + wanted_parameters = generator.parameters - initial_parameters = ParameterGroup.from_dict( - { - "kinetic": [ - ["1", 300e-3], - ["2", 500e-4], - ["3", 700e-5], - ], - "irf": [["center", 1.3], ["width", 7.8]], - "j": [["1", 1, {"vary": False, "non-negative": False}]], - } - ) - wanted_parameters = ParameterGroup.from_dict( - { - "kinetic": [ - ["1", 301e-3], - ["2", 502e-4], - ["3", 705e-5], - ], - "irf": [["center", 1.3], ["width", 7.8]], - "j": [["1", 1, {"vary": False, "non-negative": False}]], - } - ) time = np.arange(-10, 100, 1.5) pixel = np.arange(600, 750, 10) diff --git a/glotaran/model/model.py b/glotaran/model/model.py index 13c7cf24c..8d56c509c 100644 --- a/glotaran/model/model.py +++ b/glotaran/model/model.py @@ -2,6 +2,7 @@ from __future__ import annotations import copy +from typing import Any from typing import List from warnings import warn @@ -17,6 +18,7 @@ from glotaran.model.weight import Weight from glotaran.parameter import Parameter from glotaran.parameter import ParameterGroup +from glotaran.plugin_system.megacomplex_registration import get_megacomplex from glotaran.utils.ipython import MarkdownStr default_model_items = { @@ -56,18 +58,34 @@ def __init__( @classmethod def from_dict( cls, - model_dict: dict, + model_dict: dict[str, Any], *, - megacomplex_types: dict[str, type[Megacomplex]], + megacomplex_types: dict[str, type[Megacomplex]] | None = None, default_megacomplex_type: str | None = None, ) -> Model: """Creates a model from a dictionary. Parameters ---------- - model_dict : + model_dict: dict[str, Any] Dictionary containing the model. + megacomplex_types: dict[str, type[Megacomplex]] | None + Overwrite 'megacomplex_types' in ``model_dict`` for testing. + default_megacomplex_type: str | None + Overwrite 'default-megacomplex' in ``model_dict`` for testing. """ + if default_megacomplex_type is None: + default_megacomplex_type = model_dict.get("default-megacomplex") + + if megacomplex_types is None: + megacomplex_types = { + m["type"]: get_megacomplex(m["type"]) + for m in model_dict["megacomplex"].values() + if "type" in m + } + if default_megacomplex_type is not None: + megacomplex_types[default_megacomplex_type] = get_megacomplex(default_megacomplex_type) + model_dict.pop("default-megacomplex", None) model = cls( megacomplex_types=megacomplex_types, default_megacomplex_type=default_megacomplex_type diff --git a/glotaran/testing/__init__.py b/glotaran/testing/__init__.py new file mode 100644 index 000000000..a2b929333 --- /dev/null +++ b/glotaran/testing/__init__.py @@ -0,0 +1 @@ +"""Testing framework package for glotaran itself and plugins.""" diff --git a/glotaran/testing/model_generators.py b/glotaran/testing/model_generators.py new file mode 100644 index 000000000..5c59762b3 --- /dev/null +++ b/glotaran/testing/model_generators.py @@ -0,0 +1,311 @@ +"""Model generators used to generate simple models from a set of inputs.""" + +from __future__ import annotations + +from dataclasses import dataclass +from dataclasses import field +from typing import TYPE_CHECKING +from typing import Literal + +from glotaran.model import Model +from glotaran.parameter.parameter_group import ParameterGroup + +if TYPE_CHECKING: + from glotaran.utils.ipython import MarkdownStr + + +def _split_iterable_in_non_dict_and_dict_items( + input_list: list[float, dict[str, bool | float]], +) -> tuple[list[float], list[dict[str, bool | float]]]: + """Split an iterable (list) into non-dict and dict items. + + Parameters + ---------- + input_list : list[float, dict[str, bool | float]] + A list of values of type `float` and a dict with parameter options, e.g. + `[1, 2, 3, {"vary": False, "non-negative": True}]` + + Returns + ------- + tuple[list[float], list[dict[str, bool | float]]] + Split a list into non-dict (`values`) and dict items (`defaults`), + return a tuple (`values`, `defaults`) + """ + values: list = [val for val in input_list if not isinstance(val, dict)] + defaults: list = [val for val in input_list if isinstance(val, dict)] + return values, defaults + + +@dataclass +class SimpleModelGenerator: + """A minimal boilerplate model and parameters generator. + + Generates a model (together with the parameters specification) based on + parameter input values assigned to the generator's attributes + """ + + rates: list[float] = field(default_factory=list) + """A list of values representing decay rates""" + k_matrix: Literal["parallel", "sequential"] | dict[tuple[str, str], str] = "parallel" + """"A `dict` with a k_matrix specification or `Literal["parallel", "sequential"]`""" + compartments: list[str] | None = None + """A list of compartment names""" + irf: dict[str, float] = field(default_factory=dict) + """A dict of items specifying an irf""" + initial_concentration: list[float] = field(default_factory=list) + """A list values representing the initial concentration""" + dispersion_coefficients: list[float] = field(default_factory=list) + """A list of values representing the dispersion coefficients""" + dispersion_center: float | None = None + """A value representing the dispersion center""" + default_megacomplex: str = "decay" + """The default_megacomplex identifier""" + # TODO: add support for a spectral model: + # shapes: list[float] = field(default_factory=list, init=False) + + @property + def valid(self) -> bool: + """Check if the generator state is valid. + + Returns + ------- + bool + Generator state obtained by calling the generated model's + `valid` function with the generated parameters as input. + """ + try: + return self.model.valid(parameters=self.parameters) + except ValueError: + return False + + def validate(self) -> str: + """Call `validate` on the generated model and return its output. + + Returns + ------- + str + A string listing problems in the generated model and parameters if any. + """ + return self.model.validate(parameters=self.parameters) + + @property + def model(self) -> Model: + """Return the generated model. + + Returns + ------- + Model + The generated model of type :class:`glotaran.model.Model`. + """ + return Model.from_dict(self.model_dict) + + @property + def model_dict(self) -> dict: + """Return a dict representation of the generated model. + + Returns + ------- + dict + A dict representation of the generated model. + """ + return self._model_dict() + + @property + def parameters(self) -> ParameterGroup: + """Return the generated parameters of type :class:`glotaran.parameter.ParameterGroup`. + + Returns + ------- + ParameterGroup + The generated parameters of type of type :class:`glotaran.parameter.ParameterGroup`. + """ + return ParameterGroup.from_dict(self.parameters_dict) + + @property + def parameters_dict(self) -> dict: + """Return a dict representation of the generated parameters. + + Returns + ------- + dict + A dict representing the generated parameters. + """ + return self._parameters_dict() + + @property + def model_and_parameters(self) -> tuple[Model, ParameterGroup]: + """Return generated model and parameters. + + Returns + ------- + tuple[Model, ParameterGroup] + A model of type :class:`glotaran.model.Model` and + and parameters of type :class:`glotaran.parameter.ParameterGroup`. + """ + return self.model, self.parameters + + @property + def _rates(self) -> tuple[list[float], list[dict[str, bool | float]]]: + """Validate input to rates, return a tuple of rates and parameter defaults. + + Returns + ------- + tuple[list[float], list[dict[str, bool | float]]] + A tuple of a list of rates and a dict containing parameter defaults + + Raises + ------ + ValueError + Raised if rates is not a list of at least one number. + """ + if not isinstance(self.rates, list): + raise ValueError(f"generator.rates: must be a `list`, got: {self.rates}") + if len(self.rates) == 0: + raise ValueError("generator.rates: must be a `list` with 1 or more rates") + if not isinstance(self.rates[0], (int, float)): + raise ValueError(f"generator.rates: 1st element must be numeric, got: {self.rates[0]}") + return _split_iterable_in_non_dict_and_dict_items(self.rates) + + def _parameters_dict_items(self) -> dict: + """Return a dict with items used in constructing the parameters. + + Returns + ------- + dict + A dict with items used in constructing a parameters dict. + """ + rates, rates_defaults = self._rates + items = {"rates": rates} + if rates_defaults: + items.update({"rates_defaults": rates_defaults[0]}) + items.update({"irf": [[key, value] for key, value in self.irf.items()]}) + if self.initial_concentration: + items.update({"inputs": self.initial_concentration}) + elif self.k_matrix == "parallel": + items.update( + { + "inputs": [ + ["1", 1], + {"vary": False}, + ] + } + ) + elif self.k_matrix == "sequential": + items.update( + { + "inputs": [ + ["1", 1], + ["0", 0], + {"vary": False}, + ] + } + ) + return items + + def _model_dict_items(self) -> dict: + """Return a dict with items used in constructing the model. + + Returns + ------- + dict + A dict with items used in constructing a model dict. + """ + rates, _ = self._rates + nr = len(rates) + indices = list(range(1, 1 + nr)) + items = {"default-megacomplex": self.default_megacomplex} + if self.irf: + items.update( + { + "irf": { + "type": "multi-gaussian", + "center": ["irf.center"], + "width": ["irf.width"], + } + } + ) + if isinstance(self.k_matrix, dict): + items.update({"k_matrix": self.k_matrix}) + items.update({"input_parameters": [f"inputs.{i}" for i in indices]}) + items.update({"compartments": [f"s{i}" for i in indices]}) + # TODO: get unique compartments from user defined k_matrix + if self.k_matrix == "parallel": + items.update({"input_parameters": ["inputs.1"] * nr}) + items.update({"k_matrix": {(f"s{i}", f"s{i}"): f"rates.{i}" for i in indices}}) + elif self.k_matrix == "sequential": + items.update({"input_parameters": ["inputs.1"] + ["inputs.0"] * (nr - 1)}) + items.update( + {"k_matrix": {(f"s{i if i==nr else i+1}", f"s{i}"): f"rates.{i}" for i in indices}} + ) + if self.k_matrix in ("parallel", "sequential"): + items.update({"compartments": [f"s{i}" for i in indices]}) + return items + + def _parameters_dict(self) -> dict: + """Return a parameters dict. + + Returns + ------- + dict + A dict that can be passed to the `ParameterGroup` `from_dict` method. + """ + items = self._parameters_dict_items() + rates = items["rates"] + if "rates_defaults" in items: + rates += [items["rates_defaults"]] + result = {"rates": rates} + if items["irf"]: + result.update({"irf": items["irf"]}) + result.update({"inputs": items["inputs"]}) + return result + + def _model_dict(self) -> dict: + """Return a model dict. + + Returns + ------- + dict + A dict that can be passed to the `Model` `from_dict` method. + """ + items = self._model_dict_items() + result = {"default-megacomplex": items["default-megacomplex"]} + result.update( + { + "initial_concentration": { + "j1": { + "compartments": items["compartments"], + "parameters": items["input_parameters"], + }, + }, + "megacomplex": { + "mc1": {"k_matrix": ["k1"]}, + }, + "k_matrix": {"k1": {"matrix": items["k_matrix"]}}, + "dataset": { + "dataset1": { + "initial_concentration": "j1", + "megacomplex": ["mc1"], + }, + }, + } + ) + if "irf" in items: + result["dataset"]["dataset1"].update({"irf": "irf1"}) + result.update( + { + "irf": { + "irf1": items["irf"], + } + } + ) + return result + + def markdown(self) -> MarkdownStr: + """Return a markdown string representation of the generated model and parameters. + + Returns + ------- + MarkdownStr + A markdown string + """ + return self.model.markdown(parameters=self.parameters) diff --git a/glotaran/testing/plugin_system.py b/glotaran/testing/plugin_system.py new file mode 100644 index 000000000..1487eb1cd --- /dev/null +++ b/glotaran/testing/plugin_system.py @@ -0,0 +1,184 @@ +"""Mock functionality for the plugin system.""" +from __future__ import annotations + +from contextlib import ExitStack +from contextlib import contextmanager +from typing import TYPE_CHECKING +from unittest import mock + +from glotaran.plugin_system.base_registry import __PluginRegistry + +if TYPE_CHECKING: + from typing import Generator + from typing import MutableMapping + + from glotaran.io.interface import DataIoInterface + from glotaran.io.interface import ProjectIoInterface + from glotaran.model.megacomplex import Megacomplex + from glotaran.plugin_system.base_registry import _PluginType + + +@contextmanager +def _monkeypatch_plugin_registry( + register_name: str, + test_registry: MutableMapping[str, _PluginType] | None = None, + create_new_registry: bool = False, +) -> Generator[None, None, None]: + """Contextmanager to monkeypatch any Pluginregistry with name ``register_name``. + + Parameters + ---------- + register_name : str + Name of the register which should be patched. + test_registry : MutableMapping[str, _PluginType] + Registry to to update or replace the ``register_name`` registry with. + , by default None + create_new_registry : bool + Whether to update the actual registry or create a new one from ``test_registry`` + , by default False + + Yields + ------ + Generator[None, None, None] + Just to keep the context alive. + + See Also + -------- + monkeypatch_plugin_registry_megacomplex + monkeypatch_plugin_registry_data_io + monkeypatch_plugin_registry_project_io + """ + if test_registry is not None: + initila_plugins = ( + __PluginRegistry.__dict__[register_name] if not create_new_registry else {} + ) + + with mock.patch.object( + __PluginRegistry, register_name, {**initila_plugins, **test_registry} + ): + yield + else: + yield + + +@contextmanager +def monkeypatch_plugin_registry_megacomplex( + test_megacomplex: MutableMapping[str, type[Megacomplex]] | None = None, + create_new_registry: bool = False, +) -> Generator[None, None, None]: + """Monkeypatch the :class:`Megacomplex` registry. + + Parameters + ---------- + test_megacomplex : MutableMapping[str, type[Megacomplex]], optional + Registry to to update or replace the ``Megacomplex`` registry with. + , by default None + create_new_registry : bool + Whether to update the actual registry or create a new one from ``test_megacomplex`` + , by default False + + Yields + ------ + Generator[None, None, None] + Just to keep the context alive. + """ + with _monkeypatch_plugin_registry("megacomplex", test_megacomplex, create_new_registry): + yield + + +@contextmanager +def monkeypatch_plugin_registry_data_io( + test_data_io: MutableMapping[str, DataIoInterface] | None = None, + create_new_registry: bool = False, +) -> Generator[None, None, None]: + """Monkeypatch the :class:`DataIoInterface` registry. + + Parameters + ---------- + test_data_io : MutableMapping[str, DataIoInterface], optional + Registry to to update or replace the ``DataIoInterface`` registry with. + , by default None + create_new_registry : bool + Whether to update the actual registry or create a new one from ``test_data_io`` + , by default False + + Yields + ------ + Generator[None, None, None] + Just to keep the context alive. + """ + with _monkeypatch_plugin_registry("data_io", test_data_io, create_new_registry): + yield + + +@contextmanager +def monkeypatch_plugin_registry_project_io( + test_project_io: MutableMapping[str, ProjectIoInterface] | None = None, + create_new_registry: bool = False, +) -> Generator[None, None, None]: + """Monkeypatch the :class:`ProjectIoInterface` registry. + + Parameters + ---------- + test_project_io : MutableMapping[str, ProjectIoInterface], optional + Registry to to update or replace the ``ProjectIoInterface`` registry with. + , by default None + create_new_registry : bool + Whether to update the actual registry or create a new one from ``test_data_io`` + , by default False + + Yields + ------ + Generator[None, None, None] + Just to keep the context alive. + """ + with _monkeypatch_plugin_registry("project_io", test_project_io, create_new_registry): + yield + + +@contextmanager +def monkeypatch_plugin_registry( + *, + test_megacomplex: MutableMapping[str, type[Megacomplex]] | None = None, + test_data_io: MutableMapping[str, DataIoInterface] | None = None, + test_project_io: MutableMapping[str, ProjectIoInterface] | None = None, + create_new_registry: bool = False, +) -> Generator[None, None, None]: + """Contextmanager to monkeypatch multiple plugin registries at once. + + Parameters + ---------- + test_megacomplex : MutableMapping[str, type[Megacomplex]], optional + Registry to to update or replace the ``Megacomplex`` registry with. + , by default None + test_data_io : MutableMapping[str, DataIoInterface], optional + Registry to to update or replace the ``DataIoInterface`` registry with. + , by default None + test_project_io : MutableMapping[str, ProjectIoInterface], optional + Registry to to update or replace the ``ProjectIoInterface`` registry with. + , by default None + create_new_registry : bool + Whether to update the actual registry or create a new one from the arguments. + , by default False + + Yields + ------ + Generator[None, None, None] + Just keeps all context manager alive + + See Also + -------- + monkeypatch_plugin_registry_megacomplex + monkeypatch_plugin_registry_data_io + monkeypatch_plugin_registry_project_io + """ + context_managers = [ + monkeypatch_plugin_registry_megacomplex(test_megacomplex, create_new_registry), + monkeypatch_plugin_registry_data_io(test_data_io, create_new_registry), + monkeypatch_plugin_registry_project_io(test_project_io, create_new_registry), + ] + + with ExitStack() as stack: + for context_manager in context_managers: + stack.enter_context(context_manager) + yield diff --git a/glotaran/testing/test/__init__.py b/glotaran/testing/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/glotaran/testing/test/test_model_generators.py b/glotaran/testing/test/test_model_generators.py new file mode 100644 index 000000000..bd287f5cf --- /dev/null +++ b/glotaran/testing/test/test_model_generators.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from copy import deepcopy + +import pytest +from rich import pretty +from rich import print # pylint: disable=W0622 + +from glotaran.model import Model +from glotaran.parameter import ParameterGroup +from glotaran.testing.model_generators import SimpleModelGenerator + +pretty.install() + + +REF_PARAMETER_DICT = { + "rates": [ + ["1", 501e-3], + ["2", 202e-4], + ["3", 105e-5], + {"non-negative": True}, + ], + "irf": [["center", 1.3], ["width", 7.8]], + "inputs": [ + ["1", 1], + ["0", 0], + {"vary": False}, + ], +} + +REF_MODEL_DICT = { + "default-megacomplex": "decay", + "initial_concentration": { + "j1": { + "compartments": ["s1", "s2", "s3"], + "parameters": ["inputs.1", "inputs.0", "inputs.0"], + }, + }, + "megacomplex": { + "mc1": {"k_matrix": ["k1"]}, + }, + "k_matrix": { + "k1": { + "matrix": { + ("s2", "s1"): "rates.1", + ("s3", "s2"): "rates.2", + ("s3", "s3"): "rates.3", + } + } + }, + "irf": { + "irf1": { + "type": "multi-gaussian", + "center": ["irf.center"], + "width": ["irf.width"], + }, + }, + "dataset": { + "dataset1": { + "initial_concentration": "j1", + "irf": "irf1", + "megacomplex": ["mc1"], + }, + }, +} + + +def simple_diff_between_string(string1, string2): + return "".join(c2 for c1, c2 in zip(string1, string2) if c1 != c2) + + +def test_three_component_sequential_model(): + ref_model = Model.from_dict(deepcopy(REF_MODEL_DICT)) + ref_parameters = ParameterGroup.from_dict(deepcopy(REF_PARAMETER_DICT)) + generator = SimpleModelGenerator( + rates=[501e-3, 202e-4, 105e-5, {"non-negative": True}], + irf={"center": 1.3, "width": 7.8}, + k_matrix="sequential", + ) + for key, _ in REF_PARAMETER_DICT.items(): + assert key in generator.parameters_dict + # TODO: check contents + + model, parameters = generator.model_and_parameters + assert str(ref_model) == str(model), print( + simple_diff_between_string(str(model), str(ref_model)) + ) + assert str(ref_parameters) == str(parameters), print( + simple_diff_between_string(str(parameters), str(ref_parameters)) + ) + + +def test_only_rates_no_irf(): + generator = SimpleModelGenerator(rates=[0.1, 0.02, 0.003]) + assert "irf" not in generator.model_dict.keys() + + +def test_no_rates(): + generator = SimpleModelGenerator() + assert generator.valid is False + + +def test_one_rate(): + generator = SimpleModelGenerator([1]) + assert generator.valid is True + assert "is valid" in generator.validate() + + +def test_rates_not_a_list(): + generator = SimpleModelGenerator(1) + assert generator.valid is False + with pytest.raises(ValueError): + print(generator.validate()) + + +def test_set_rates_delayed(): + generator = SimpleModelGenerator() + generator.rates = [1, 2, 3] + assert generator.valid is True diff --git a/glotaran/testing/test/test_plugin_system.py b/glotaran/testing/test/test_plugin_system.py new file mode 100644 index 000000000..df1b2d9b7 --- /dev/null +++ b/glotaran/testing/test/test_plugin_system.py @@ -0,0 +1,91 @@ +import pytest + +from glotaran.io import DataIoInterface +from glotaran.io import ProjectIoInterface +from glotaran.model import Megacomplex +from glotaran.model import megacomplex +from glotaran.plugin_system.data_io_registration import known_data_formats +from glotaran.plugin_system.megacomplex_registration import known_megacomplex_names +from glotaran.plugin_system.project_io_registration import known_project_formats +from glotaran.testing.plugin_system import monkeypatch_plugin_registry +from glotaran.testing.plugin_system import monkeypatch_plugin_registry_data_io +from glotaran.testing.plugin_system import monkeypatch_plugin_registry_megacomplex +from glotaran.testing.plugin_system import monkeypatch_plugin_registry_project_io + + +@megacomplex(dimension="test") +class DummyMegacomplex(Megacomplex): + pass + + +class DummyDataIo(DataIoInterface): + pass + + +class DummyProjectIo(ProjectIoInterface): + pass + + +def test_monkeypatch_megacomplexes(): + """Megacomplex only added to registry while context is entered.""" + with monkeypatch_plugin_registry_megacomplex(test_megacomplex={"test_mc": DummyMegacomplex}): + assert "test_mc" in known_megacomplex_names() + + assert "test_mc" not in known_megacomplex_names() + with monkeypatch_plugin_registry(test_megacomplex={"test_full": DummyMegacomplex}): + assert "test_full" in known_megacomplex_names() + + assert "test_full" not in known_megacomplex_names() + + +def test_monkeypatch_data_io(): + """DataIoInterface only added to registry while context is entered.""" + with monkeypatch_plugin_registry_data_io( + test_data_io={"test_dio": DummyDataIo(format_name="test")} + ): + assert "test_dio" in known_data_formats() + + assert "test_mc" not in known_data_formats() + + with monkeypatch_plugin_registry(test_data_io={"test_full": DummyDataIo(format_name="test")}): + assert "test_full" in known_data_formats() + + assert "test_full" not in known_data_formats() + + +def test_monkeypatch_project_io(): + """ProjectIoInterface only added to registry while context is entered.""" + with monkeypatch_plugin_registry_project_io( + test_project_io={"test_pio": DummyProjectIo(format_name="test")} + ): + assert "test_pio" in known_project_formats() + + assert "test_pio" not in known_megacomplex_names() + with monkeypatch_plugin_registry( + test_project_io={"test_full": DummyProjectIo(format_name="test")} + ): + assert "test_full" in known_project_formats() + + assert "test_full" not in known_project_formats() + + +@pytest.mark.parametrize("create_new_registry", (True, False)) +def test_monkeypatch_plugin_registry_full(create_new_registry: bool): + """Create a completely new registry.""" + + assert "decay" in known_megacomplex_names() + assert "yml" in known_project_formats() + assert "sdt" in known_data_formats() + + with monkeypatch_plugin_registry( + test_megacomplex={"test_mc": DummyMegacomplex}, + test_project_io={"test_pio": DummyProjectIo(format_name="test")}, + test_data_io={"test_dio": DummyDataIo(format_name="test")}, + create_new_registry=create_new_registry, + ): + assert "test_mc" in known_megacomplex_names() + assert "test_pio" in known_project_formats() + assert "test_dio" in known_data_formats() + assert ("decay" not in known_megacomplex_names()) is create_new_registry + assert ("yml" not in known_project_formats()) is create_new_registry + assert ("sdt" not in known_data_formats()) is create_new_registry diff --git a/requirements_dev.txt b/requirements_dev.txt index e63b6b521..6df755111 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -7,6 +7,7 @@ asteval==0.9.25 numpy==1.21.1 scipy==1.7.0 click==8.0.1 +rich==10.9.0 numba==0.53.1 pandas==1.3.1 pyyaml==5.4.1 diff --git a/setup.cfg b/setup.cfg index 977820588..022bd304e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,7 @@ install_requires = numpy>=1.20.0 pandas>=0.25.2 pyyaml>=5.2 + rich>=10.9.0 scipy>=1.3.2 sdtfile>=2020.8.3 setuptools>=41.2