🩹 False positive model validation fail when combining multiple default megacomplexes (#797)

* 🩹 Fixed missing 'f' in front of f-string, which rendered the error message useless
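
For context, a minimal illustration of that bug (hypothetical values, not project code): without the 'f' prefix Python never interpolates the placeholder, so the reported error contained the literal text '{self.label}' instead of the dataset label.

```python
label = "dataset1"
print("in dataset {label}")   # missing 'f': prints 'in dataset {label}'
print(f"in dataset {label}")  # with 'f': prints 'in dataset dataset1'
```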

* 🩹 Deactivate instance check if megacomplex type is None

This preserves compatibility with legacy (0.4.0) model specs and with combining multiple megacomplexes

* 🩹 Fixed usage of 'glotaran_unique' in 'ensure_unique_megacomplexes'

* 👌 Improved typing

* 🧪 Added unittests for 'ensure_unique_megacomplexes'

* 🩹 Fixed missing megacomplex definition not being reported as an error

* 🩹 Fixed megacomplex type being None

* ♻️ Refactored 'ensure_unique_megacomplexes' using Counter and improved typing.
Since 'DatasetModel' is never instantiated from the class definition itself, but 'create_dataset_model_type' creates a new class in a closure each time a new instance is needed, values won't be overwritten across instances.
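
To make the closure point concrete, here is a minimal sketch of the pattern (illustrative names only, not the actual glotaran implementation): every call returns a fresh class object, so class-level values attached to one dataset model type cannot leak into another.

```python
def create_type(properties: dict) -> type:
    # A brand-new class object is created on every call, so setting
    # class-level attributes on one returned type cannot affect another.
    class _DatasetModel:
        pass

    for name, value in properties.items():
        setattr(_DatasetModel, name, value)
    return _DatasetModel


a = create_type({"megacomplex": ["m1"]})()
b = create_type({"megacomplex": ["m2"]})()
assert type(a) is not type(b)  # two distinct classes from the closure
assert (a.megacomplex, b.megacomplex) == (["m1"], ["m2"])
```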

* ♻️ Moved typing information to stub file
After a personal request from @joernweissenborn, I removed the class attributes again and moved them to a pyi file

* 🧹 Removed warning about missing megacomplex definition
Since it is already reported as 'Missing Model Item'
s-weigand authored and jsnel committed Sep 16, 2021
1 parent 90b9ce7 commit a9642a2
Showing 4 changed files with 184 additions and 30 deletions.
78 changes: 48 additions & 30 deletions glotaran/model/dataset_model.py
@@ -1,22 +1,26 @@
"""The DatasetModel class."""
from __future__ import annotations

+from collections import Counter
from typing import TYPE_CHECKING
+from typing import Generator

import numpy as np
import xarray as xr

from glotaran.model.item import model_item
from glotaran.model.item import model_item_validator
+from glotaran.parameter import Parameter

if TYPE_CHECKING:
+    from typing import Any
-    from typing import Generator
+    from typing import Hashable

    from glotaran.model.megacomplex import Megacomplex
    from glotaran.model.model import Model
-    from glotaran.parameter import Parameter


-def create_dataset_model_type(properties: dict[str, any]) -> type:
+def create_dataset_model_type(properties: dict[str, Any]) -> type[DatasetModel]:
    @model_item(properties=properties)
    class ModelDatasetModel(DatasetModel):
        pass
@@ -33,13 +37,17 @@ class DatasetModel:
parameter.
"""

-    def iterate_megacomplexes(self) -> Generator[tuple[Parameter | int, Megacomplex | str]]:
+    def iterate_megacomplexes(
+        self,
+    ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]:
        """Iterates over the dataset model's megacomplexes."""
for i, megacomplex in enumerate(self.megacomplex):
scale = self.megacomplex_scale[i] if self.megacomplex_scale is not None else None
yield scale, megacomplex

-    def iterate_global_megacomplexes(self) -> Generator[tuple[Parameter | int, Megacomplex | str]]:
+    def iterate_global_megacomplexes(
+        self,
+    ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]:
        """Iterates over the dataset model's global megacomplexes."""
for i, megacomplex in enumerate(self.global_megacomplex):
scale = (
@@ -63,7 +71,7 @@ def get_model_dimension(self) -> str:
)
return self._model_dimension

-    def finalize_data(self, dataset: xr.Dataset):
+    def finalize_data(self, dataset: xr.Dataset) -> None:
is_full_model = self.has_global_model()
for megacomplex in self.megacomplex:
megacomplex.finalize_data(self, dataset, is_full_model=is_full_model)
@@ -73,7 +81,7 @@ def finalize_data(self, dataset: xr.Dataset):
self, dataset, is_full_model=is_full_model, as_global=True
)

-    def overwrite_model_dimension(self, model_dimension: str):
+    def overwrite_model_dimension(self, model_dimension: str) -> None:
"""Overwrites the dataset model's model dimension."""
self._model_dimension = model_dimension

@@ -104,11 +112,11 @@ def get_global_dimension(self) -> str:
)
return self._global_dimension

-    def overwrite_global_dimension(self, global_dimension: str):
+    def overwrite_global_dimension(self, global_dimension: str) -> None:
"""Overwrites the dataset model's global dimension."""
self._global_dimension = global_dimension

-    def swap_dimensions(self):
+    def swap_dimensions(self) -> None:
"""Swaps the dataset model's global and model dimension."""
global_dimension = self.get_model_dimension()
model_dimension = self.get_global_dimension()
@@ -117,9 +125,7 @@ def swap_dimensions(self):

def set_data(self, dataset: xr.Dataset) -> DatasetModel:
"""Sets the dataset model's data."""
-        self._coords: dict[str, np.ndarray] = {
-            name: dim.values for name, dim in dataset.coords.items()
-        }
+        self._coords = {name: dim.values for name, dim in dataset.coords.items()}
self._data: np.ndarray = dataset.data.values
self._weight: np.ndarray | None = dataset.weight.values if "weight" in dataset else None
if self._weight is not None:
@@ -152,7 +158,7 @@ def set_coordinates(self, coords: dict[str, np.ndarray]):
"""Sets the dataset model's coordinates."""
self._coords = coords

-    def get_coordinates(self) -> np.ndarray:
+    def get_coordinates(self) -> dict[Hashable, np.ndarray]:
"""Gets the dataset model's coordinates."""
return self._coords

@@ -166,20 +172,32 @@ def get_global_axis(self) -> np.ndarray:

@model_item_validator(False)
def ensure_unique_megacomplexes(self, model: Model) -> list[str]:

-        megacomplexes = [model.megacomplex[m] for m in self.megacomplex if m in model.megacomplex]
-        types = {type(m) for m in megacomplexes}
-        problems = []
-
-        for megacomplex_type in types:
-            if not megacomplex_type.glotaran_unique:
-                continue
-            instances = [m for m in megacomplexes if isinstance(m, megacomplex_type)]
-            n = len(instances)
-            if n != 1:
-                problems.append(
-                    f"Multiple instances of unique megacomplex type '{instances[0].type}' "
-                    "in dataset {self.label}"
-                )
-
-        return problems
"""Ensure that unique megacomplexes Are only used once per dataset.
Parameters
----------
model : Model
Model object using this dataset model.
Returns
-------
list[str]
Error messages to be shown when the model gets validated.
"""
glotaran_unique_megacomplex_types = []

for megacomplex_name in self.megacomplex:
try:
megacomplex_instance = model.megacomplex[megacomplex_name]
if type(megacomplex_instance).glotaran_unique() is True:
type_name = megacomplex_instance.type or megacomplex_instance.name
glotaran_unique_megacomplex_types.append(type_name)
except KeyError:
pass

return [
f"Multiple instances of unique megacomplex type {type_name!r} "
f"in dataset {self.label!r}"
for type_name, count in Counter(glotaran_unique_megacomplex_types).most_common()
if count > 1
]
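
As a standalone illustration of the Counter pattern the new validator relies on, here is a minimal sketch (the type names and dataset label are made up; the real inputs come from the model's megacomplex definitions):

```python
from collections import Counter

# Hypothetical type names collected from a dataset's unique megacomplexes.
unique_types = ["baseline", "baseline", "coherent-artifact"]

problems = [
    f"Multiple instances of unique megacomplex type {name!r} in dataset 'ds1'"
    for name, count in Counter(unique_types).most_common()
    if count > 1
]
print(problems)
# ["Multiple instances of unique megacomplex type 'baseline' in dataset 'ds1'"]
```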
47 changes: 47 additions & 0 deletions glotaran/model/dataset_model.pyi
@@ -0,0 +1,47 @@
from __future__ import annotations

from typing import Any
from typing import Generator
from typing import Hashable

import numpy as np
import xarray as xr

from glotaran.model.megacomplex import Megacomplex
from glotaran.model.model import Model
from glotaran.parameter import Parameter

def create_dataset_model_type(properties: dict[str, Any]) -> type[DatasetModel]: ...

class DatasetModel:

label: str
megacomplex: list[str]
megacomplex_scale: list[Parameter] | None
global_megacomplex: list[str]
global_megacomplex_scale: list[Parameter] | None
scale: Parameter | None
_coords: dict[Hashable, np.ndarray]
def iterate_megacomplexes(
self,
) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: ...
def iterate_global_megacomplexes(
self,
) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: ...
def get_model_dimension(self) -> str: ...
def finalize_data(self, dataset: xr.Dataset) -> None: ...
def overwrite_model_dimension(self, model_dimension: str) -> None: ...
def get_global_dimension(self) -> str: ...
def overwrite_global_dimension(self, global_dimension: str) -> None: ...
def swap_dimensions(self) -> None: ...
def set_data(self, dataset: xr.Dataset) -> DatasetModel: ...
def get_data(self) -> np.ndarray: ...
def get_weight(self) -> np.ndarray | None: ...
def index_dependent(self) -> bool: ...
def overwrite_index_dependent(self, index_dependent: bool): ...
def has_global_model(self) -> bool: ...
def set_coordinates(self, coords: dict[str, np.ndarray]): ...
def get_coordinates(self) -> dict[Hashable, np.ndarray]: ...
def get_model_axis(self) -> np.ndarray: ...
def get_global_axis(self) -> np.ndarray: ...
def ensure_unique_megacomplexes(self, model: Model) -> list[str]: ...
1 change: 1 addition & 0 deletions glotaran/model/megacomplex.py
@@ -70,6 +70,7 @@ def decorator(cls):
megacomplex_type = model_item(properties=properties, has_type=True)(cls)

    if register_as is not None:
+        megacomplex_type.name = register_as
        register_megacomplex(register_as, megacomplex_type)

return megacomplex_type
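
Setting 'name' at registration time matters for the validator above: 'megacomplex_instance.type or megacomplex_instance.name' falls back to the registered name when an instance's 'type' is still None, as with legacy (0.4.0) specs. A minimal sketch of that fallback, using a hypothetical stand-in class rather than project code:

```python
class FakeMegacomplex:
    # Hypothetical stand-in: 'type' may be None for legacy model specs,
    # while 'name' is set when the megacomplex type is registered.
    type = None
    name = "baseline"

megacomplex_instance = FakeMegacomplex()
type_name = megacomplex_instance.type or megacomplex_instance.name
assert type_name == "baseline"  # falls back to the registered name
```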
88 changes: 88 additions & 0 deletions glotaran/model/test/test_dataset_model.py
@@ -0,0 +1,88 @@
"""Tests for glotaran.model.dataset_model.DatasetModel"""
from __future__ import annotations

import pytest

from glotaran.builtin.megacomplexes.baseline import BaselineMegacomplex
from glotaran.builtin.megacomplexes.coherent_artifact import CoherentArtifactMegacomplex
from glotaran.builtin.megacomplexes.damped_oscillation import DampedOscillationMegacomplex
from glotaran.builtin.megacomplexes.decay import DecayMegacomplex
from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex
from glotaran.model.dataset_model import create_dataset_model_type
from glotaran.model.model import default_dataset_properties


class MockModel:
"""Test Model only containing the megacomplex property.
Multiple and different kinds of megacomplexes are defined
but only a subset will be used by the DatasetModel.
"""

def __init__(self) -> None:
self.megacomplex = {
# not unique
"d1": DecayMegacomplex(),
"d2": DecayMegacomplex(),
"d3": DecayMegacomplex(),
"s1": SpectralMegacomplex(),
"s2": SpectralMegacomplex(),
"s3": SpectralMegacomplex(),
"doa1": DampedOscillationMegacomplex(),
"doa2": DampedOscillationMegacomplex(),
# unique
"b1": BaselineMegacomplex(),
"b2": BaselineMegacomplex(),
"c1": CoherentArtifactMegacomplex(),
"c2": CoherentArtifactMegacomplex(),
}


@pytest.mark.parametrize(
"used_megacomplexes, expected_problems",
(
(
["d1"],
[],
),
(
["d1", "d2", "d3"],
[],
),
(
["s1", "s2", "s3"],
[],
),
(
["d1", "d2", "d3", "s1", "s2", "s3", "doa1", "doa2", "b1", "c1"],
[],
),
(
["d1", "b1", "b2"],
["Multiple instances of unique megacomplex type 'baseline' in dataset 'ds1'"],
),
(
["d1", "c1", "c2"],
["Multiple instances of unique megacomplex type 'coherent-artifact' in dataset 'ds1'"],
),
(
["d1", "b1", "b2", "c1", "c2"],
[
"Multiple instances of unique megacomplex type 'baseline' in dataset 'ds1'",
"Multiple instances of unique megacomplex type "
"'coherent-artifact' in dataset 'ds1'",
],
),
),
)
def test_datasetmodel_ensure_unique_megacomplexes(
used_megacomplexes: list[str], expected_problems: list[str]
):
"""Only report problems if multiple unique megacomplexes of the same type are used."""
dataset_model = create_dataset_model_type({**default_dataset_properties})()
dataset_model.megacomplex = used_megacomplexes # type:ignore
dataset_model.label = "ds1" # type:ignore
problems = dataset_model.ensure_unique_megacomplexes(MockModel()) # type:ignore

assert len(problems) == len(expected_problems)
assert problems == expected_problems
