✨ Make guidance spectra a first class citizen #983

Closed · wants to merge 9 commits
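Summary of the change: the `overwrite_*` dimension setters are renamed to `set_*`, the dataset gains `override_model_dimension` / `override_global_dimension` properties, and a new `clp_guidance` dataset property is introduced, backed by a synthesized `ClpGuidanceMegacomplex`, so a guidance spectrum can be attached to a dataset directly instead of through a dedicated megacomplex type. A minimal sketch of how a guidance dataset might be declared under the new schema — the dataset labels and the clp label "s1" are invented for illustration; only the `clp_guidance` key and the now-optional `megacomplex` key are taken from the diff:

    # Hypothetical model dict, illustration only: "guidance" and "s1" are made up.
    model_spec = {
        "dataset": {
            "dataset1": {"megacomplex": ["m1"]},
            # A guidance dataset needs no megacomplexes of its own:
            # "megacomplex" now allows None, and iterate_megacomplexes()
            # synthesizes a ClpGuidanceMegacomplex for it.
            "guidance": {"clp_guidance": "s1"},
        },
    }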
@@ -46,7 +46,7 @@ def test_baseline():
     pixel = np.asarray([0])
     coords = {"time": time, "pixel": pixel}
     dataset_model = model.dataset["dataset1"].fill(model, parameter)
-    dataset_model.overwrite_global_dimension("pixel")
+    dataset_model.set_global_dimension("pixel")
     dataset_model.set_coordinates(coords)
     matrix = calculate_matrix(dataset_model, {})
     compartments = matrix.clp_labels
@@ -89,7 +89,7 @@ def test_coherent_artifact(spectral_dependence: str):
     coords = {"time": time, "spectral": spectral}
 
     dataset_model = model.dataset["dataset1"].fill(model, parameters)
-    dataset_model.overwrite_global_dimension("spectral")
+    dataset_model.set_global_dimension("spectral")
     dataset_model.set_coordinates(coords)
     matrix = calculate_matrix(dataset_model, {"spectral": 1})
     compartments = matrix.clp_labels
@@ -95,7 +95,7 @@ class OneCompartmentModelInvertedAxis:
     axis = {"time": time, "spectral": spectral}
 
     decay_dataset_model = decay_model.dataset["dataset1"].fill(decay_model, decay_parameters)
-    decay_dataset_model.overwrite_global_dimension("spectral")
+    decay_dataset_model.set_global_dimension("spectral")
     decay_dataset_model.set_coordinates(axis)
     matrix = calculate_matrix(decay_dataset_model, {})
     decay_compartments = matrix.clp_labels
@@ -157,7 +157,7 @@ class OneCompartmentModelNegativeSkew:
     axis = {"time": time, "spectral": spectral}
 
     decay_dataset_model = decay_model.dataset["dataset1"].fill(decay_model, decay_parameters)
-    decay_dataset_model.overwrite_global_dimension("spectral")
+    decay_dataset_model.set_global_dimension("spectral")
     decay_dataset_model.set_coordinates(axis)
     matrix = calculate_matrix(decay_dataset_model, {})
     decay_compartments = matrix.clp_labels
@@ -261,7 +261,7 @@ class ThreeCompartmentModel:
     axis = {"time": time, "spectral": spectral}
 
     decay_dataset_model = decay_model.dataset["dataset1"].fill(decay_model, decay_parameters)
-    decay_dataset_model.overwrite_global_dimension("spectral")
+    decay_dataset_model.set_global_dimension("spectral")
     decay_dataset_model.set_coordinates(axis)
     matrix = calculate_matrix(decay_dataset_model, {})
     decay_compartments = matrix.clp_labels
41 changes: 41 additions & 0 deletions glotaran/model/clp_guidance_megacomplex.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import xarray as xr
+
+from glotaran.model.megacomplex import Megacomplex
+from glotaran.model.megacomplex import megacomplex
+
+if TYPE_CHECKING:
+    from glotaran.model.dataset_model import DatasetModel
+
+
+@megacomplex()
+class ClpGuidanceMegacomplex(Megacomplex):
+    def calculate_matrix(
+        self,
+        dataset_model: DatasetModel,
+        indices: dict[str, int],
+        **kwargs,
+    ):
+        model_axis = dataset_model.get_model_axis()
+        clp_label = [dataset_model.clp_guidance]
+        matrix = np.ones((model_axis.size, 1), dtype=np.float64)
+        return clp_label, matrix
+
+    def index_dependent(self, dataset_model: DatasetModel) -> bool:
+        return False
+
+    def finalize_data(
+        self,
+        dataset_model: DatasetModel,
+        dataset: xr.Dataset,
+        is_full_model: bool = False,
+        as_global: bool = False,
+    ):
+        if not is_full_model:
+            dataset["estimated_clp_guidance"] = dataset.clp.sel(
+                clp_label=dataset_model.clp_guidance
+            )
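The megacomplex's matrix is a single column of ones labeled with the guided clp, so fitting the guidance dataset reduces to reading the guidance spectrum straight into that clp's estimate. A quick sketch of the shapes involved — the axis length and the label "s1" are assumed, not from the PR:

    import numpy as np

    # What calculate_matrix yields for a model axis of length 5 and
    # clp_guidance == "s1" (assumed label): one clp, a column of ones.
    model_axis = np.arange(5)
    clp_labels = ["s1"]
    matrix = np.ones((model_axis.size, 1), dtype=np.float64)
    # In the least-squares step, matrix @ clp ≈ data then estimates the
    # guided clp directly from the guidance data at each global index.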
57 changes: 38 additions & 19 deletions glotaran/model/dataset_model.py
@@ -7,6 +7,7 @@
 import numpy as np
 import xarray as xr
 
+from glotaran.model.clp_guidance_megacomplex import ClpGuidanceMegacomplex
 from glotaran.model.item import model_item
 from glotaran.model.item import model_item_validator
@@ -40,15 +41,19 @@ class DatasetModel:
     def iterate_megacomplexes(
         self,
     ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]:
-        """Iterates of der dataset model's megacomplexes."""
-        for i, megacomplex in enumerate(self.megacomplex):
-            scale = self.megacomplex_scale[i] if self.megacomplex_scale is not None else None
-            yield scale, megacomplex
+        """Iterate over the dataset model's megacomplexes."""
+        if self.clp_guidance is not None:
+            scale = self.megacomplex_scale or None
+            yield scale, ClpGuidanceMegacomplex()
+        else:
+            for i, megacomplex in enumerate(self.megacomplex):
+                scale = self.megacomplex_scale[i] if self.megacomplex_scale is not None else None
+                yield scale, megacomplex

     def iterate_global_megacomplexes(
         self,
     ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]:
-        """Iterates of der dataset model's global megacomplexes."""
+        """Iterate over the dataset model's global megacomplexes."""
         for i, megacomplex in enumerate(self.global_megacomplex):
             scale = (
                 self.global_megacomplex_scale[i]
@@ -57,47 +62,61 @@ def iterate_global_megacomplexes(
             )
             yield scale, megacomplex
 
+    def has_megacomplexes(self) -> bool:
+        return len(list(self.iterate_megacomplexes())) != 0
+
+    def has_global_megacomplexes(self) -> bool:
+        return len(list(self.iterate_global_megacomplexes())) != 0
+
     def get_model_dimension(self) -> str:
         """Returns the dataset model's model dimension."""
+        if self.override_model_dimension is not None:
+            return self.override_model_dimension
         if not hasattr(self, "_model_dimension"):
-            if len(self.megacomplex) == 0:
+            if not self.has_megacomplexes():
                 raise ValueError(f"No megacomplex set for dataset model '{self.label}'")
-            if isinstance(self.megacomplex[0], str):
+            first_megacomplex = next(self.iterate_megacomplexes())[1]
+            if isinstance(first_megacomplex, str):
                 raise ValueError(f"Dataset model '{self.label}' was not filled")
-            self._model_dimension = self.megacomplex[0].dimension
-            if any(self._model_dimension != m.dimension for m in self.megacomplex):
+            self._model_dimension = first_megacomplex.dimension
+            if any(self._model_dimension != m.dimension for _, m in self.iterate_megacomplexes()):
                 raise ValueError(
                     f"Megacomplex dimensions do not match for dataset model '{self.label}'."
                 )
         return self._model_dimension

     def finalize_data(self, dataset: xr.Dataset) -> None:
         is_full_model = self.has_global_model()
-        for megacomplex in self.megacomplex:
+        for _, megacomplex in self.iterate_megacomplexes():
             megacomplex.finalize_data(self, dataset, is_full_model=is_full_model)
         if is_full_model:
-            for megacomplex in self.global_megacomplex:
+            for _, megacomplex in self.iterate_global_megacomplexes():
                 megacomplex.finalize_data(
                     self, dataset, is_full_model=is_full_model, as_global=True
                 )
 
-    def overwrite_model_dimension(self, model_dimension: str) -> None:
+    def set_model_dimension(self, model_dimension: str) -> None:
         """Overwrites the dataset model's model dimension."""
         self._model_dimension = model_dimension

     # TODO: make explicit we only support 2 dimensions at present
     # TODO: the global dimension should become a flexible index (MultiIndex)
     #       the user can then specify the name of the MultiIndex global dimension
-    #       using the function overwrite_global_dimension
+    #       using the function override_global_dimension
     #       e.g. in FLIM, x, y dimension may get 'flattened' to a MultiIndex 'pixel'
     def get_global_dimension(self) -> str:
         """Returns the dataset model's global dimension."""
+        if self.override_global_dimension is not None:
+            return self.override_global_dimension
         if not hasattr(self, "_global_dimension"):
             if self.has_global_model():
                 if isinstance(self.global_megacomplex[0], str):
                     raise ValueError(f"Dataset model '{self.label}' was not filled")
                 self._global_dimension = self.global_megacomplex[0].dimension
-                if any(self._global_dimension != m.dimension for m in self.global_megacomplex):
+                if any(
+                    self._global_dimension != m.dimension
+                    for _, m in self.iterate_global_megacomplexes()
+                ):
                     raise ValueError(
                         "Global megacomplex dimensions do not "
                         f"match for dataset model '{self.label}'."
@@ -112,16 +131,16 @@ def get_global_dimension(self) -> str:
                     )
         return self._global_dimension

-    def overwrite_global_dimension(self, global_dimension: str) -> None:
+    def set_global_dimension(self, global_dimension: str) -> None:
         """Overwrites the dataset model's global dimension."""
         self._global_dimension = global_dimension
 
     def swap_dimensions(self) -> None:
         """Swaps the dataset model's global and model dimension."""
         global_dimension = self.get_model_dimension()
         model_dimension = self.get_global_dimension()
-        self.overwrite_global_dimension(global_dimension)
-        self.overwrite_model_dimension(model_dimension)
+        self.set_global_dimension(global_dimension)
+        self.set_model_dimension(model_dimension)
 
     def set_data(self, dataset: xr.Dataset) -> DatasetModel:
         """Sets the dataset model's data."""
@@ -144,7 +163,7 @@ def is_index_dependent(self) -> bool:
         """Indicates if the dataset model is index dependent."""
         if hasattr(self, "_index_dependent"):
             return self._index_dependent
-        return any(m.index_dependent(self) for m in self.megacomplex)
+        return any(m.index_dependent(self) for _, m in self.iterate_megacomplexes())
 
     def overwrite_index_dependent(self, index_dependent: bool):
         """Overrides the index dependency of the dataset"""
@@ -186,7 +205,7 @@ def ensure_unique_megacomplexes(self, model: Model) -> list[str]:
         """
         glotaran_unique_megacomplex_types = []
 
-        for megacomplex_name in self.megacomplex:
+        for _, megacomplex_name in self.iterate_megacomplexes():
             try:
                 megacomplex_instance = model.megacomplex[megacomplex_name]
                 if type(megacomplex_instance).glotaran_unique() is True:
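After this refactor every consumer iterates `(scale, megacomplex)` pairs instead of indexing `self.megacomplex`, which is what lets a guidance-only dataset (with `megacomplex=None`) pass the existing checks. A sketch of the resulting pattern — `dataset_model` is assumed to be a filled instance with `clp_guidance` set and no `_index_dependent` override:

    # Consumers unpack (scale, megacomplex) pairs; a dataset with only
    # clp_guidance set still yields one megacomplex, the synthesized
    # ClpGuidanceMegacomplex, so has_megacomplexes() holds for it too.
    for scale, megacomplex in dataset_model.iterate_megacomplexes():
        print(scale, type(megacomplex).__name__)

    assert dataset_model.has_megacomplexes()
    assert not dataset_model.is_index_dependent()  # guidance matrix is index independent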
8 changes: 6 additions & 2 deletions glotaran/model/dataset_model.pyi
@@ -16,6 +16,10 @@ def create_dataset_model_type(properties: dict[str, Any]) -> type[DatasetModel]:
 class DatasetModel:
 
     label: str
+    clp_guidance: str | None
+    override_model_dimension: str | None
+    override_global_dimension: str | None
+    group: str
     megacomplex: list[str]
     megacomplex_scale: list[Parameter] | None
     global_megacomplex: list[str]
@@ -30,9 +34,9 @@ class DatasetModel:
     ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: ...
     def get_model_dimension(self) -> str: ...
     def finalize_data(self, dataset: xr.Dataset) -> None: ...
-    def overwrite_model_dimension(self, model_dimension: str) -> None: ...
+    def set_model_dimension(self, model_dimension: str) -> None: ...
     def get_global_dimension(self) -> str: ...
-    def overwrite_global_dimension(self, global_dimension: str) -> None: ...
+    def set_global_dimension(self, global_dimension: str) -> None: ...
     def swap_dimensions(self) -> None: ...
     def set_data(self, dataset: xr.Dataset) -> DatasetModel: ...
     def get_data(self) -> np.ndarray: ...
26 changes: 15 additions & 11 deletions glotaran/model/model.py
@@ -34,12 +34,15 @@
 }
 
 default_dataset_properties = {
+    "clp_guidance": {"type": str, "allow_none": True},
+    "override_model_dimension": {"type": str, "allow_none": True},
+    "override_global_dimension": {"type": str, "allow_none": True},
     "group": {"type": str, "default": "default"},
-    "megacomplex": List[str],
+    "megacomplex": {"type": List[str], "allow_none": True},
     "megacomplex_scale": {"type": List[Parameter], "allow_none": True},
     "global_megacomplex": {"type": List[str], "allow_none": True},
-    "global_megacomplex_scale": {"type": List[Parameter], "default": None, "allow_none": True},
-    "scale": {"type": Parameter, "default": None, "allow_none": True},
+    "global_megacomplex_scale": {"type": List[Parameter], "allow_none": True},
+    "scale": {"type": Parameter, "allow_none": True},
 }


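Each entry in `default_dataset_properties` is either a bare type or a spec dict; `_add_dataset_property` (next hunk) normalizes both forms to the underlying type before comparing them. A small sketch of that normalization under the two forms, with illustrative variable names:

    from typing import List

    # Two equivalent declarations of a dataset property's type:
    bare_form = List[str]                                # type only
    dict_form = {"type": List[str], "allow_none": True}  # type plus flags

    # Normalization as done in _add_dataset_property:
    known_type = dict_form["type"] if isinstance(dict_form, dict) else dict_form
    assert known_type == bare_form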
@@ -212,15 +215,17 @@ def _add_model_item(self, item_name: str, item: type):
     def _add_dataset_property(self, property_name: str, dataset_property: dict[str, any]):
         if property_name in self._dataset_properties:
             known_type = (
-                self._dataset_properties[property_name]
-                if not isinstance(self._dataset_properties, dict)
-                else self._dataset_properties[property_name]["type"]
+                self._dataset_properties[property_name]["type"]
+                if isinstance(self._dataset_properties, dict)
+                else self._dataset_properties[property_name]
             )
 
             new_type = (
-                dataset_property
-                if not isinstance(dataset_property, dict)
-                else dataset_property["type"]
+                dataset_property["type"]
+                if isinstance(dataset_property, dict)
+                else dataset_property
             )
 
             if known_type != new_type:
                 raise ModelError(
                     f"Cannot add dataset property of type {property_name} as it was "
@@ -374,8 +379,7 @@ def validate(self, parameters: ParameterGroup = None, raise_exception: bool = False):
         """
         result = ""
 
-        problems = self.problem_list(parameters)
-        if problems:
+        if problems := self.problem_list(parameters):
             result = f"Your model has {len(problems)} problems:\n"
             for p in problems:
                 result += f"\n * {p}"