🩹 False positive model validation fail when combining multiple default megacomplexes #797

Merged
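A reading of why the old validator misfired (an inference from the diff below, not an authoritative write-up): the removed code tested `megacomplex_type.glotaran_unique` without calling it. If `glotaran_unique` is a method, the bare attribute reference is always truthy, so every megacomplex type, including the default non-unique ones, was treated as unique, and combining several of them failed validation. A minimal, self-contained Python sketch of that pitfall, using a hypothetical class name:

```python
class DecayLike:
    """Hypothetical stand-in for a non-unique megacomplex type."""

    @classmethod
    def glotaran_unique(cls) -> bool:
        return False  # several of these per dataset is fine

# An old-style guard like `if not megacomplex_type.glotaran_unique: continue`
# never skips, because the method object itself is truthy:
assert bool(DecayLike.glotaran_unique) is True  # method object, not its result
assert DecayLike.glotaran_unique() is False     # the actual uniqueness flag
```

The fixed validator calls `glotaran_unique()` and compares the returned flag, as shown in the first file of the diff.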
78 changes: 48 additions & 30 deletions glotaran/model/dataset_model.py
@@ -1,22 +1,26 @@
"""The DatasetModel class."""
from __future__ import annotations

from collections import Counter
from typing import TYPE_CHECKING
from typing import Generator

import numpy as np
import xarray as xr

from glotaran.model.item import model_item
from glotaran.model.item import model_item_validator
from glotaran.parameter import Parameter

if TYPE_CHECKING:
from typing import Any
from typing import Generator
from typing import Hashable

from glotaran.model.megacomplex import Megacomplex
from glotaran.model.model import Model
from glotaran.parameter import Parameter


-def create_dataset_model_type(properties: dict[str, any]) -> type:
+def create_dataset_model_type(properties: dict[str, Any]) -> type[DatasetModel]:
     @model_item(properties=properties)
     class ModelDatasetModel(DatasetModel):
         pass
@@ -33,13 +37,17 @@ class DatasetModel:
     parameter.
     """

-    def iterate_megacomplexes(self) -> Generator[tuple[Parameter | int, Megacomplex | str]]:
+    def iterate_megacomplexes(
+        self,
+    ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]:
         """Iterates over the dataset model's megacomplexes."""
         for i, megacomplex in enumerate(self.megacomplex):
             scale = self.megacomplex_scale[i] if self.megacomplex_scale is not None else None
             yield scale, megacomplex

-    def iterate_global_megacomplexes(self) -> Generator[tuple[Parameter | int, Megacomplex | str]]:
+    def iterate_global_megacomplexes(
+        self,
+    ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]:
         """Iterates over the dataset model's global megacomplexes."""
         for i, megacomplex in enumerate(self.global_megacomplex):
             scale = (
@@ -63,7 +71,7 @@ def get_model_dimension(self) -> str:
             )
         return self._model_dimension

-    def finalize_data(self, dataset: xr.Dataset):
+    def finalize_data(self, dataset: xr.Dataset) -> None:
         is_full_model = self.has_global_model()
         for megacomplex in self.megacomplex:
             megacomplex.finalize_data(self, dataset, is_full_model=is_full_model)
@@ -73,7 +81,7 @@ def finalize_data(self, dataset: xr.Dataset):
                     self, dataset, is_full_model=is_full_model, as_global=True
                 )

-    def overwrite_model_dimension(self, model_dimension: str):
+    def overwrite_model_dimension(self, model_dimension: str) -> None:
         """Overwrites the dataset model's model dimension."""
         self._model_dimension = model_dimension

@@ -104,11 +112,11 @@ def get_global_dimension(self) -> str:
             )
         return self._global_dimension

-    def overwrite_global_dimension(self, global_dimension: str):
+    def overwrite_global_dimension(self, global_dimension: str) -> None:
         """Overwrites the dataset model's global dimension."""
         self._global_dimension = global_dimension

-    def swap_dimensions(self):
+    def swap_dimensions(self) -> None:
         """Swaps the dataset model's global and model dimension."""
         global_dimension = self.get_model_dimension()
         model_dimension = self.get_global_dimension()
@@ -117,9 +125,7 @@ def swap_dimensions(self):

     def set_data(self, dataset: xr.Dataset) -> DatasetModel:
         """Sets the dataset model's data."""
-        self._coords: dict[str, np.ndarray] = {
-            name: dim.values for name, dim in dataset.coords.items()
-        }
+        self._coords = {name: dim.values for name, dim in dataset.coords.items()}
         self._data: np.ndarray = dataset.data.values
         self._weight: np.ndarray | None = dataset.weight.values if "weight" in dataset else None
         if self._weight is not None:
@@ -152,7 +158,7 @@ def set_coordinates(self, coords: dict[str, np.ndarray]):
"""Sets the dataset model's coordinates."""
self._coords = coords

def get_coordinates(self) -> np.ndarray:
def get_coordinates(self) -> dict[Hashable, np.ndarray]:
"""Gets the dataset model's coordinates."""
return self._coords

@@ -166,20 +172,32 @@ def get_global_axis(self) -> np.ndarray:

     @model_item_validator(False)
     def ensure_unique_megacomplexes(self, model: Model) -> list[str]:
-
-        megacomplexes = [model.megacomplex[m] for m in self.megacomplex if m in model.megacomplex]
-        types = {type(m) for m in megacomplexes}
-        problems = []
-
-        for megacomplex_type in types:
-            if not megacomplex_type.glotaran_unique:
-                continue
-            instances = [m for m in megacomplexes if isinstance(m, megacomplex_type)]
-            n = len(instances)
-            if n != 1:
-                problems.append(
-                    f"Multiple instances of unique megacomplex type '{instances[0].type}' "
-                    "in dataset {self.label}"
-                )
-
-        return problems
+        """Ensure that unique megacomplexes are only used once per dataset.
+
+        Parameters
+        ----------
+        model : Model
+            Model object using this dataset model.
+
+        Returns
+        -------
+        list[str]
+            Error messages to be shown when the model gets validated.
+        """
+        glotaran_unique_megacomplex_types = []
+
+        for megacomplex_name in self.megacomplex:
+            try:
+                megacomplex_instance = model.megacomplex[megacomplex_name]
+                if type(megacomplex_instance).glotaran_unique() is True:
+                    type_name = megacomplex_instance.type or megacomplex_instance.name
+                    glotaran_unique_megacomplex_types.append(type_name)
+            except KeyError:
+                pass
+
+        return [
+            f"Multiple instances of unique megacomplex type {type_name!r} "
+            f"in dataset {self.label!r}"
+            for type_name, count in Counter(glotaran_unique_megacomplex_types).most_common()
+            if count > 1
+        ]
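The new validator collects the type names of unique megacomplexes actually referenced by the dataset and reports only names that occur more than once. The `Counter` idiom in isolation, with illustrative values rather than anything from the test suite:

```python
from collections import Counter

# Type names gathered for megacomplexes whose class reports glotaran_unique() == True
unique_types_seen = ["baseline", "baseline", "coherent-artifact"]

problems = [
    f"Multiple instances of unique megacomplex type {name!r} in dataset 'ds1'"
    for name, count in Counter(unique_types_seen).most_common()
    if count > 1
]
print(problems)
# ["Multiple instances of unique megacomplex type 'baseline' in dataset 'ds1'"]
```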
47 changes: 47 additions & 0 deletions glotaran/model/dataset_model.pyi
@@ -0,0 +1,47 @@
from __future__ import annotations

from typing import Any
from typing import Generator
from typing import Hashable

import numpy as np
import xarray as xr

from glotaran.model.megacomplex import Megacomplex
from glotaran.model.model import Model
from glotaran.parameter import Parameter

def create_dataset_model_type(properties: dict[str, Any]) -> type[DatasetModel]: ...

class DatasetModel:

    label: str
    megacomplex: list[str]
    megacomplex_scale: list[Parameter] | None
    global_megacomplex: list[str]
    global_megacomplex_scale: list[Parameter] | None
    scale: Parameter | None
    _coords: dict[Hashable, np.ndarray]
    def iterate_megacomplexes(
        self,
    ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: ...
    def iterate_global_megacomplexes(
        self,
    ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: ...
    def get_model_dimension(self) -> str: ...
    def finalize_data(self, dataset: xr.Dataset) -> None: ...
    def overwrite_model_dimension(self, model_dimension: str) -> None: ...
    def get_global_dimension(self) -> str: ...
    def overwrite_global_dimension(self, global_dimension: str) -> None: ...
    def swap_dimensions(self) -> None: ...
    def set_data(self, dataset: xr.Dataset) -> DatasetModel: ...
    def get_data(self) -> np.ndarray: ...
    def get_weight(self) -> np.ndarray | None: ...
    def index_dependent(self) -> bool: ...
    def overwrite_index_dependent(self, index_dependent: bool): ...
    def has_global_model(self) -> bool: ...
    def set_coordinates(self, coords: dict[str, np.ndarray]): ...
    def get_coordinates(self) -> dict[Hashable, np.ndarray]: ...
    def get_model_axis(self) -> np.ndarray: ...
    def get_global_axis(self) -> np.ndarray: ...
    def ensure_unique_megacomplexes(self, model: Model) -> list[str]: ...
1 change: 1 addition & 0 deletions glotaran/model/megacomplex.py
@@ -70,6 +70,7 @@ def decorator(cls):
         megacomplex_type = model_item(properties=properties, has_type=True)(cls)

         if register_as is not None:
+            megacomplex_type.name = register_as
             register_megacomplex(register_as, megacomplex_type)

         return megacomplex_type
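Setting `name` on the registered class appears to back the `megacomplex_instance.type or megacomplex_instance.name` fallback used by the validator above, so a megacomplex whose `type` is unset can still be named in error messages. A rough illustration of that intent with a hypothetical class, not the decorator's actual code path:

```python
class MyMegacomplex:
    """Hypothetical megacomplex whose `type` property is unset."""

    type = None

# What the decorator now does for classes registered via `register_as`:
MyMegacomplex.name = "my-megacomplex"

# The validator's fallback can then still produce a readable label:
label = MyMegacomplex.type or MyMegacomplex.name
assert label == "my-megacomplex"
```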
88 changes: 88 additions & 0 deletions glotaran/model/test/test_dataset_model.py
@@ -0,0 +1,88 @@
"""Tests for glotaran.model.dataset_model.DatasetModel"""
from __future__ import annotations

import pytest

from glotaran.builtin.megacomplexes.baseline import BaselineMegacomplex
from glotaran.builtin.megacomplexes.coherent_artifact import CoherentArtifactMegacomplex
from glotaran.builtin.megacomplexes.damped_oscillation import DampedOscillationMegacomplex
from glotaran.builtin.megacomplexes.decay import DecayMegacomplex
from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex
from glotaran.model.dataset_model import create_dataset_model_type
from glotaran.model.model import default_dataset_properties


class MockModel:
"""Test Model only containing the megacomplex property.

Multiple and different kinds of megacomplexes are defined
but only a subset will be used by the DatsetModel.
"""

def __init__(self) -> None:
self.megacomplex = {
# not unique
"d1": DecayMegacomplex(),
"d2": DecayMegacomplex(),
"d3": DecayMegacomplex(),
"s1": SpectralMegacomplex(),
"s2": SpectralMegacomplex(),
"s3": SpectralMegacomplex(),
"doa1": DampedOscillationMegacomplex(),
"doa2": DampedOscillationMegacomplex(),
# unique
"b1": BaselineMegacomplex(),
"b2": BaselineMegacomplex(),
"c1": CoherentArtifactMegacomplex(),
"c2": CoherentArtifactMegacomplex(),
}


@pytest.mark.parametrize(
    "used_megacomplexes, expected_problems",
    (
        (
            ["d1"],
            [],
        ),
        (
            ["d1", "d2", "d3"],
            [],
        ),
        (
            ["s1", "s2", "s3"],
            [],
        ),
        (
            ["d1", "d2", "d3", "s1", "s2", "s3", "doa1", "doa2", "b1", "c1"],
            [],
        ),
        (
            ["d1", "b1", "b2"],
            ["Multiple instances of unique megacomplex type 'baseline' in dataset 'ds1'"],
        ),
        (
            ["d1", "c1", "c2"],
            ["Multiple instances of unique megacomplex type 'coherent-artifact' in dataset 'ds1'"],
        ),
        (
            ["d1", "b1", "b2", "c1", "c2"],
            [
                "Multiple instances of unique megacomplex type 'baseline' in dataset 'ds1'",
                "Multiple instances of unique megacomplex type "
                "'coherent-artifact' in dataset 'ds1'",
            ],
        ),
    ),
)
def test_datasetmodel_ensure_unique_megacomplexes(
    used_megacomplexes: list[str], expected_problems: list[str]
):
    """Only report problems if multiple unique megacomplexes of the same type are used."""
    dataset_model = create_dataset_model_type({**default_dataset_properties})()
    dataset_model.megacomplex = used_megacomplexes  # type:ignore
    dataset_model.label = "ds1"  # type:ignore
    problems = dataset_model.ensure_unique_megacomplexes(MockModel())  # type:ignore

    assert len(problems) == len(expected_problems)
    assert problems == expected_problems