Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add clp guidance megacomplex (Sourcery refactored) #1031

Closed
wants to merge 8 commits into from
2 changes: 2 additions & 0 deletions glotaran/analysis/optimization_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,8 @@ def create_result_dataset(
dataset_model = self.dataset_models[label]
global_dimension = dataset_model.get_global_dimension()
model_dimension = dataset_model.get_model_dimension()
dataset.attrs["global_dimension"] = global_dimension
dataset.attrs["model_dimension"] = model_dimension
if copy:
dataset = dataset.copy()
if dataset_model.is_index_dependent():
Expand Down
1 change: 1 addition & 0 deletions glotaran/builtin/megacomplexes/clp_guide/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from glotaran.builtin.megacomplexes.clp_guide.clp_guide_megacomplex import ClpGuideMegacomplex
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

import numpy as np
import xarray as xr

from glotaran.model import DatasetModel
from glotaran.model import Megacomplex
from glotaran.model import megacomplex


@megacomplex(exclusive=True, register_as="clp-guide", properties={"target": str})
class ClpGuideMegacomplex(Megacomplex):
    """Megacomplex that guides a single clp (conditionally linear parameter).

    Contributes a constant 1x1 matrix for the clp named by ``target``, so a
    guidance dataset can constrain that clp during optimization. Declared
    ``exclusive``: it must be the only megacomplex in its dataset model.
    """

    def calculate_matrix(
        self,
        dataset_model: DatasetModel,
        indices: dict[str, int],
        **kwargs,
    ):
        """Return the clp labels and the constant guide matrix."""
        # A single constant entry ties this dataset to the targeted clp.
        labels = [self.target]
        guide_matrix = np.ones((1, 1), dtype=np.float64)
        return labels, guide_matrix

    def index_dependent(self, dataset_model: DatasetModel) -> bool:
        """The guide matrix is constant, hence never index dependent."""
        return False

    def finalize_data(
        self,
        dataset_model: DatasetModel,
        dataset: xr.Dataset,
        is_full_model: bool = False,
        as_global: bool = False,
    ):
        """No extra data needs to be attached to the result dataset."""
        pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import numpy as np

from glotaran.analysis.optimize import optimize
from glotaran.analysis.simulation import simulate
from glotaran.builtin.megacomplexes.clp_guide import ClpGuideMegacomplex
from glotaran.builtin.megacomplexes.decay import DecaySequentialMegacomplex
from glotaran.builtin.megacomplexes.decay.test.test_decay_megacomplex import create_gaussian_clp
from glotaran.model import Model
from glotaran.parameter import ParameterGroup
from glotaran.project import Scheme


def test_clp_guide():
    """The clp-guide megacomplex lets a guidance dataset constrain the 's1' clp.

    Simulates a two-compartment sequential decay (dataset1), attaches a second
    dataset holding the known 's1' spectrum via the 'clp-guide' megacomplex
    (dataset2), and checks that optimization recovers the wanted parameters.
    """
    model = Model.from_dict(
        {
            "dataset_groups": {"default": {"link_clp": True}},
            "megacomplex": {
                "mc1": {
                    "type": "decay-sequential",
                    "compartments": ["s1", "s2"],
                    "rates": ["1", "2"],
                },
                "mc2": {"type": "clp-guide", "dimension": "time", "target": "s1"},
            },
            "dataset": {
                "dataset1": {"megacomplex": ["mc1"]},
                "dataset2": {"megacomplex": ["mc2"]},
            },
        },
        megacomplex_types={
            "decay-sequential": DecaySequentialMegacomplex,
            "clp-guide": ClpGuideMegacomplex,
        },
    )

    # Start one order of magnitude away from the wanted rates; the third
    # parameter (a fixed scale of 1) must not vary.
    initial_parameters = ParameterGroup.from_list(
        [101e-5, 501e-4, [1, {"vary": False, "non-negative": False}]]
    )
    wanted_parameters = ParameterGroup.from_list(
        [101e-4, 501e-3, [1, {"vary": False, "non-negative": False}]]
    )

    time = np.arange(0, 50, 1.5)
    pixel = np.arange(600, 750, 5)
    axis = {"time": time, "pixel": pixel}

    clp = create_gaussian_clp(["s1", "s2"], [7, 30], [620, 720], [10, 50], pixel)

    dataset1 = simulate(model, "dataset1", wanted_parameters, axis, clp)
    # The guide dataset is the known 's1' spectrum relabeled onto the model axis.
    dataset2 = clp.sel(clp_label=["s1"]).rename(clp_label="time")
    data = {"dataset1": dataset1, "dataset2": dataset2}

    scheme = Scheme(
        model=model,
        parameters=initial_parameters,
        data=data,
        maximum_number_function_evaluations=20,
    )
    result = optimize(scheme)

    for label, param in result.optimized_parameters.all():
        assert np.allclose(param.value, wanted_parameters.get(label).value, rtol=1e-1)
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from glotaran.project import Scheme


def _create_gaussian_clp(labels, amplitudes, centers, widths, axis):
def create_gaussian_clp(labels, amplitudes, centers, widths, axis):
return xr.DataArray(
[
amplitudes[i] * np.exp(-np.log(2) * np.square(2 * (axis - centers[i]) / widths[i]))
Expand Down Expand Up @@ -179,9 +179,7 @@ class ThreeComponentParallel:

axis = {"time": time, "pixel": pixel}

clp = _create_gaussian_clp(
["s1", "s2", "s3"], [7, 3, 30], [620, 670, 720], [10, 30, 50], pixel
)
clp = create_gaussian_clp(["s1", "s2", "s3"], [7, 3, 30], [620, 670, 720], [10, 30, 50], pixel)


class ThreeComponentSequential:
Expand Down Expand Up @@ -240,9 +238,7 @@ class ThreeComponentSequential:
pixel = np.arange(600, 750, 10)
axis = {"time": time, "pixel": pixel}

clp = _create_gaussian_clp(
["s1", "s2", "s3"], [7, 3, 30], [620, 670, 720], [10, 30, 50], pixel
)
clp = create_gaussian_clp(["s1", "s2", "s3"], [7, 3, 30], [620, 670, 720], [10, 30, 50], pixel)


@pytest.mark.parametrize(
Expand Down
88 changes: 70 additions & 18 deletions glotaran/model/dataset_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@ class DatasetModel:

def iterate_megacomplexes(
    self,
) -> Generator[tuple[Parameter | str | None, Megacomplex | str], None, None]:
    """Iterate the dataset model's megacomplexes.

    Yields
    ------
    tuple[Parameter | str | None, Megacomplex | str]
        The megacomplex's scale (``None`` when no scales are defined on the
        dataset model) and the megacomplex itself.
    """
    for i, megacomplex in enumerate(self.megacomplex):
        scale = self.megacomplex_scale[i] if self.megacomplex_scale is not None else None
        yield scale, megacomplex

def iterate_global_megacomplexes(
self,
) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]:
"""Iterates of der dataset model's global megacomplexes."""
) -> Generator[tuple[Parameter | str | None, Megacomplex | str], None, None]:
"""Iterates the dataset model's global megacomplexes."""
for i, megacomplex in enumerate(self.global_megacomplex):
scale = (
self.global_megacomplex_scale[i]
Expand Down Expand Up @@ -172,7 +172,7 @@ def get_global_axis(self) -> np.ndarray:

@model_item_validator(False)
def ensure_unique_megacomplexes(self, model: Model) -> list[str]:
    """Ensure that unique megacomplexes are only used once per dataset.

    Parameters
    ----------
    model : Model
        Model object using this dataset model.

    Returns
    -------
    list[str]
        Error messages to be shown when the model gets validated.
    """
    errors = []

    def get_unique_errors(megacomplexes: list[str], is_global: bool) -> list[str]:
        # Collect the type names of every megacomplex declared unique.
        unique_types = []
        for megacomplex_name in megacomplexes:
            try:
                megacomplex_instance = model.megacomplex[megacomplex_name]
                if type(megacomplex_instance).glotaran_unique():
                    type_name = megacomplex_instance.type or megacomplex_instance.name
                    unique_types.append(type_name)
            except KeyError:
                # The megacomplex does not exist, the model validator will report this
                pass
        # Any type occurring more than once violates uniqueness.
        return [
            f"Multiple instances of unique{' global ' if is_global else ' '}"
            f"megacomplex type {type_name!r} in dataset {self.label!r}"
            for type_name, count in Counter(unique_types).most_common()
            if count > 1
        ]

    if self.megacomplex:
        errors += get_unique_errors(self.megacomplex, False)
    if self.global_megacomplex:
        errors += get_unique_errors(self.global_megacomplex, True)

    return errors

for megacomplex_name in self.megacomplex:
@model_item_validator(False)
def ensure_exclusive_megacomplexes(self, model: Model) -> list[str]:
    """Ensure that exclusive megacomplexes are the only megacomplex in the dataset model.

    Parameters
    ----------
    model : Model
        Model object using this dataset model.

    Returns
    -------
    list[str]
        Error messages to be shown when the model gets validated.
    """
    errors = []

    # NOTE(review): is_global is currently unused in the message; kept for
    # signature parity with ensure_unique_megacomplexes' helper.
    def get_exclusive_errors(megacomplexes: list[str], is_global: bool) -> list[str]:
        try:
            exclusive_megacomplex = next(
                model.megacomplex[label]
                for label in megacomplexes
                if label in model.megacomplex
                and type(model.megacomplex[label]).glotaran_exclusive()
            )
        except StopIteration:
            # No exclusive megacomplex present, nothing to check.
            return []
        # Bug fix: check the list being validated (megacomplexes), not always
        # self.megacomplex, so the rule also applies to global megacomplexes.
        if len(megacomplexes) != 1:
            return [
                f"Megacomplex '{type(exclusive_megacomplex)}' is exclusive and cannot be "
                f"combined with other megacomplex in dataset model '{self.label}'."
            ]
        return []

    if self.megacomplex:
        errors += get_exclusive_errors(self.megacomplex, False)
    if self.global_megacomplex:
        errors += get_exclusive_errors(self.global_megacomplex, True)

    return errors
6 changes: 6 additions & 0 deletions glotaran/model/megacomplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def megacomplex(
dataset_model_items: dict[str, dict[str, Any]] = None,
dataset_properties: Any | dict[str, dict[str, Any]] = None,
unique: bool = False,
exclusive: bool = False,
register_as: str | None = None,
):
"""The `@megacomplex` decorator is intended to be used on subclasses of
Expand Down Expand Up @@ -67,6 +68,7 @@ def decorator(cls):
setattr(cls, "_glotaran_megacomplex_dataset_model_items", dataset_model_items)
setattr(cls, "_glotaran_megacomplex_dataset_properties", dataset_properties)
setattr(cls, "_glotaran_megacomplex_unique", unique)
setattr(cls, "_glotaran_megacomplex_exclusive", exclusive)

megacomplex_type = model_item(properties=properties, has_type=True)(cls)

Expand Down Expand Up @@ -140,3 +142,7 @@ def glotaran_dataset_properties(cls) -> str:
@classmethod
def glotaran_unique(cls) -> bool:
    """Whether only one megacomplex of this type may be used per dataset (set via ``@megacomplex(unique=...)``)."""
    return cls._glotaran_megacomplex_unique

@classmethod
def glotaran_exclusive(cls) -> bool:
    """Whether this megacomplex type must be the only one in a dataset model (set via ``@megacomplex(exclusive=...)``)."""
    return cls._glotaran_megacomplex_exclusive
25 changes: 12 additions & 13 deletions glotaran/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def _add_dict_items(self, item_name: str, items: dict):

for label, item in items.items():
item_cls = self.model_items[item_name]
is_typed = hasattr(item_cls, "_glotaran_model_item_typed")
if is_typed:
if hasattr(item_cls, "_glotaran_model_item_typed"):
if "type" not in item and item_cls.get_default_type() is None:
raise ValueError(f"Missing type for attribute '{item_name}'")
item_type = item.get("type", item_cls.get_default_type())
Expand All @@ -156,8 +155,7 @@ def _add_list_items(self, item_name: str, items: list):

for item in items:
item_cls = self.model_items[item_name]
is_typed = hasattr(item_cls, "_glotaran_model_item_typed")
if is_typed:
if hasattr(item_cls, "_glotaran_model_item_typed"):
if "type" not in item:
raise ValueError(f"Missing type for attribute '{item_name}'")
item_type = item["type"]
Expand Down Expand Up @@ -212,15 +210,17 @@ def _add_model_item(self, item_name: str, item: type):
def _add_dataset_property(self, property_name: str, dataset_property: dict[str, any]):
if property_name in self._dataset_properties:
known_type = (
self._dataset_properties[property_name]
if not isinstance(self._dataset_properties, dict)
else self._dataset_properties[property_name]["type"]
self._dataset_properties[property_name]["type"]
if isinstance(self._dataset_properties, dict)
else self._dataset_properties[property_name]
)

new_type = (
dataset_property
if not isinstance(dataset_property, dict)
else dataset_property["type"]
dataset_property["type"]
if isinstance(dataset_property, dict)
else dataset_property
)

if known_type != new_type:
raise ModelError(
f"Cannot add dataset property of type {property_name} as it was "
Expand Down Expand Up @@ -357,7 +357,7 @@ def problem_list(self, parameters: ParameterGroup = None) -> list[str]:
for item in items:
problems += item.validate(self, parameters=parameters)
else:
for _, item in items.items():
for item in items.values():
problems += item.validate(self, parameters=parameters)

return problems
Expand All @@ -374,8 +374,7 @@ def validate(self, parameters: ParameterGroup = None, raise_exception: bool = Fa
"""
result = ""

problems = self.problem_list(parameters)
if problems:
if problems := self.problem_list(parameters):
result = f"Your model has {len(problems)} problems:\n"
for p in problems:
result += f"\n * {p}"
Expand Down
Loading