diff --git a/glotaran/model/dataset_model.py b/glotaran/model/dataset_model.py
index d86a99ad3..2290ff3c9 100644
--- a/glotaran/model/dataset_model.py
+++ b/glotaran/model/dataset_model.py
@@ -1,22 +1,26 @@
 """The DatasetModel class."""
 from __future__ import annotations

+from collections import Counter
 from typing import TYPE_CHECKING
-from typing import Generator

 import numpy as np
 import xarray as xr

 from glotaran.model.item import model_item
 from glotaran.model.item import model_item_validator
-from glotaran.parameter import Parameter

 if TYPE_CHECKING:
+    from typing import Any
+    from typing import Generator
+    from typing import Hashable
+
     from glotaran.model.megacomplex import Megacomplex
     from glotaran.model.model import Model
+    from glotaran.parameter import Parameter


-def create_dataset_model_type(properties: dict[str, any]) -> type:
+def create_dataset_model_type(properties: dict[str, Any]) -> type[DatasetModel]:
     @model_item(properties=properties)
     class ModelDatasetModel(DatasetModel):
         pass
@@ -33,13 +37,17 @@ class DatasetModel:
     parameter.
     """

-    def iterate_megacomplexes(self) -> Generator[tuple[Parameter | int, Megacomplex | str]]:
+    def iterate_megacomplexes(
+        self,
+    ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]:
         """Iterates of der dataset model's megacomplexes."""
         for i, megacomplex in enumerate(self.megacomplex):
             scale = self.megacomplex_scale[i] if self.megacomplex_scale is not None else None
             yield scale, megacomplex

-    def iterate_global_megacomplexes(self) -> Generator[tuple[Parameter | int, Megacomplex | str]]:
+    def iterate_global_megacomplexes(
+        self,
+    ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]:
         """Iterates of der dataset model's global megacomplexes."""
         for i, megacomplex in enumerate(self.global_megacomplex):
             scale = (
@@ -63,7 +71,7 @@ def get_model_dimension(self) -> str:
         )
         return self._model_dimension

-    def finalize_data(self, dataset: xr.Dataset):
+    def finalize_data(self, dataset: xr.Dataset) -> None:
         is_full_model = self.has_global_model()
         for megacomplex in self.megacomplex:
             megacomplex.finalize_data(self, dataset, is_full_model=is_full_model)
@@ -73,7 +81,7 @@ def finalize_data(self, dataset: xr.Dataset):
                 self, dataset, is_full_model=is_full_model, as_global=True
             )

-    def overwrite_model_dimension(self, model_dimension: str):
+    def overwrite_model_dimension(self, model_dimension: str) -> None:
         """Overwrites the dataset model's model dimension."""
         self._model_dimension = model_dimension

@@ -104,11 +112,11 @@ def get_global_dimension(self) -> str:
         )
         return self._global_dimension

-    def overwrite_global_dimension(self, global_dimension: str):
+    def overwrite_global_dimension(self, global_dimension: str) -> None:
         """Overwrites the dataset model's global dimension."""
         self._global_dimension = global_dimension

-    def swap_dimensions(self):
+    def swap_dimensions(self) -> None:
         """Swaps the dataset model's global and model dimension."""
         global_dimension = self.get_model_dimension()
         model_dimension = self.get_global_dimension()
@@ -117,9 +125,7 @@ def swap_dimensions(self):

     def set_data(self, dataset: xr.Dataset) -> DatasetModel:
         """Sets the dataset model's data."""
-        self._coords: dict[str, np.ndarray] = {
-            name: dim.values for name, dim in dataset.coords.items()
-        }
+        self._coords = {name: dim.values for name, dim in dataset.coords.items()}
         self._data: np.ndarray = dataset.data.values
         self._weight: np.ndarray | None = dataset.weight.values if "weight" in dataset else None
         if self._weight is not None:
@@ -152,7 +158,7 @@ def set_coordinates(self, coords: dict[str, np.ndarray]):
         """Sets the dataset model's coordinates."""
         self._coords = coords

-    def get_coordinates(self) -> np.ndarray:
+    def get_coordinates(self) -> dict[Hashable, np.ndarray]:
         """Gets the dataset model's coordinates."""
         return self._coords

@@ -166,20 +172,32 @@ def get_global_axis(self) -> np.ndarray:

     @model_item_validator(False)
     def ensure_unique_megacomplexes(self, model: Model) -> list[str]:
-
-        megacomplexes = [model.megacomplex[m] for m in self.megacomplex if m in model.megacomplex]
-        types = {type(m) for m in megacomplexes}
-        problems = []
-
-        for megacomplex_type in types:
-            if not megacomplex_type.glotaran_unique:
-                continue
-            instances = [m for m in megacomplexes if isinstance(m, megacomplex_type)]
-            n = len(instances)
-            if n != 1:
-                problems.append(
-                    f"Multiple instances of unique megacomplex type '{instances[0].type}' "
-                    "in dataset {self.label}"
-                )
-
-        return problems
+        """Ensure that unique megacomplexes are only used once per dataset.
+
+        Parameters
+        ----------
+        model : Model
+            Model object using this dataset model.
+
+        Returns
+        -------
+        list[str]
+            Error messages to be shown when the model gets validated.
+        """
+        glotaran_unique_megacomplex_types = []
+
+        for megacomplex_name in self.megacomplex:
+            try:
+                megacomplex_instance = model.megacomplex[megacomplex_name]
+                if type(megacomplex_instance).glotaran_unique() is True:
+                    type_name = megacomplex_instance.type or megacomplex_instance.name
+                    glotaran_unique_megacomplex_types.append(type_name)
+            except KeyError:
+                pass
+
+        return [
+            f"Multiple instances of unique megacomplex type {type_name!r} "
+            f"in dataset {self.label!r}"
+            for type_name, count in Counter(glotaran_unique_megacomplex_types).most_common()
+            if count > 1
+        ]
diff --git a/glotaran/model/dataset_model.pyi b/glotaran/model/dataset_model.pyi
new file mode 100644
index 000000000..c0a7b49f7
--- /dev/null
+++ b/glotaran/model/dataset_model.pyi
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+from typing import Any
+from typing import Generator
+from typing import Hashable
+
+import numpy as np
+import xarray as xr
+
+from glotaran.model.megacomplex import Megacomplex
+from glotaran.model.model import Model
+from glotaran.parameter import Parameter
+
+def create_dataset_model_type(properties: dict[str, Any]) -> type[DatasetModel]: ...
+
+class DatasetModel:
+
+    label: str
+    megacomplex: list[str]
+    megacomplex_scale: list[Parameter] | None
+    global_megacomplex: list[str]
+    global_megacomplex_scale: list[Parameter] | None
+    scale: Parameter | None
+    _coords: dict[Hashable, np.ndarray]
+    def iterate_megacomplexes(
+        self,
+    ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: ...
+    def iterate_global_megacomplexes(
+        self,
+    ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: ...
+    def get_model_dimension(self) -> str: ...
+    def finalize_data(self, dataset: xr.Dataset) -> None: ...
+    def overwrite_model_dimension(self, model_dimension: str) -> None: ...
+    def get_global_dimension(self) -> str: ...
+    def overwrite_global_dimension(self, global_dimension: str) -> None: ...
+    def swap_dimensions(self) -> None: ...
+    def set_data(self, dataset: xr.Dataset) -> DatasetModel: ...
+    def get_data(self) -> np.ndarray: ...
+    def get_weight(self) -> np.ndarray | None: ...
+    def index_dependent(self) -> bool: ...
+    def overwrite_index_dependent(self, index_dependent: bool): ...
+    def has_global_model(self) -> bool: ...
+    def set_coordinates(self, coords: dict[str, np.ndarray]): ...
+    def get_coordinates(self) -> dict[Hashable, np.ndarray]: ...
+    def get_model_axis(self) -> np.ndarray: ...
+    def get_global_axis(self) -> np.ndarray: ...
+    def ensure_unique_megacomplexes(self, model: Model) -> list[str]: ...
diff --git a/glotaran/model/megacomplex.py b/glotaran/model/megacomplex.py
index 471dcf4ac..e8ae6afc1 100644
--- a/glotaran/model/megacomplex.py
+++ b/glotaran/model/megacomplex.py
@@ -70,6 +70,7 @@ def decorator(cls):
         megacomplex_type = model_item(properties=properties, has_type=True)(cls)

         if register_as is not None:
+            megacomplex_type.name = register_as
             register_megacomplex(register_as, megacomplex_type)

         return megacomplex_type
diff --git a/glotaran/model/test/test_dataset_model.py b/glotaran/model/test/test_dataset_model.py
new file mode 100644
index 000000000..ac89ab74d
--- /dev/null
+++ b/glotaran/model/test/test_dataset_model.py
@@ -0,0 +1,88 @@
+"""Tests for glotaran.model.dataset_model.DatasetModel"""
+from __future__ import annotations
+
+import pytest
+
+from glotaran.builtin.megacomplexes.baseline import BaselineMegacomplex
+from glotaran.builtin.megacomplexes.coherent_artifact import CoherentArtifactMegacomplex
+from glotaran.builtin.megacomplexes.damped_oscillation import DampedOscillationMegacomplex
+from glotaran.builtin.megacomplexes.decay import DecayMegacomplex
+from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex
+from glotaran.model.dataset_model import create_dataset_model_type
+from glotaran.model.model import default_dataset_properties
+
+
+class MockModel:
+    """Test Model only containing the megacomplex property.
+
+    Multiple and different kinds of megacomplexes are defined
+    but only a subset will be used by the DatasetModel.
+ """ + + def __init__(self) -> None: + self.megacomplex = { + # not unique + "d1": DecayMegacomplex(), + "d2": DecayMegacomplex(), + "d3": DecayMegacomplex(), + "s1": SpectralMegacomplex(), + "s2": SpectralMegacomplex(), + "s3": SpectralMegacomplex(), + "doa1": DampedOscillationMegacomplex(), + "doa2": DampedOscillationMegacomplex(), + # unique + "b1": BaselineMegacomplex(), + "b2": BaselineMegacomplex(), + "c1": CoherentArtifactMegacomplex(), + "c2": CoherentArtifactMegacomplex(), + } + + +@pytest.mark.parametrize( + "used_megacomplexes, expected_problems", + ( + ( + ["d1"], + [], + ), + ( + ["d1", "d2", "d3"], + [], + ), + ( + ["s1", "s2", "s3"], + [], + ), + ( + ["d1", "d2", "d3", "s1", "s2", "s3", "doa1", "doa2", "b1", "c1"], + [], + ), + ( + ["d1", "b1", "b2"], + ["Multiple instances of unique megacomplex type 'baseline' in dataset 'ds1'"], + ), + ( + ["d1", "c1", "c2"], + ["Multiple instances of unique megacomplex type 'coherent-artifact' in dataset 'ds1'"], + ), + ( + ["d1", "b1", "b2", "c1", "c2"], + [ + "Multiple instances of unique megacomplex type 'baseline' in dataset 'ds1'", + "Multiple instances of unique megacomplex type " + "'coherent-artifact' in dataset 'ds1'", + ], + ), + ), +) +def test_datasetmodel_ensure_unique_megacomplexes( + used_megacomplexes: list[str], expected_problems: list[str] +): + """Only report problems if multiple unique megacomplexes of the same type are used.""" + dataset_model = create_dataset_model_type({**default_dataset_properties})() + dataset_model.megacomplex = used_megacomplexes # type:ignore + dataset_model.label = "ds1" # type:ignore + problems = dataset_model.ensure_unique_megacomplexes(MockModel()) # type:ignore + + assert len(problems) == len(expected_problems) + assert problems == expected_problems