diff --git a/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py b/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py index 362972e4c..68d9c0941 100644 --- a/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py @@ -4,12 +4,16 @@ from typing import List import numpy as np +import xarray as xr -from glotaran.builtin.megacomplexes.decay.decay_megacomplex_base import DecayMegacomplexBase from glotaran.builtin.megacomplexes.decay.initial_concentration import InitialConcentration from glotaran.builtin.megacomplexes.decay.irf import Irf from glotaran.builtin.megacomplexes.decay.k_matrix import KMatrix +from glotaran.builtin.megacomplexes.decay.util import calculate_matrix +from glotaran.builtin.megacomplexes.decay.util import finalize_data +from glotaran.builtin.megacomplexes.decay.util import index_dependent from glotaran.model import DatasetModel +from glotaran.model import Megacomplex from glotaran.model import ModelError from glotaran.model import megacomplex @@ -26,7 +30,7 @@ }, register_as="decay", ) -class DecayMegacomplex(DecayMegacomplexBase): +class DecayMegacomplex(Megacomplex): """A Megacomplex with one or more K-Matrices.""" def get_compartments(self, dataset_model: DatasetModel) -> list[str]: @@ -57,3 +61,37 @@ def get_k_matrix(self) -> KMatrix: else: full_k_matrix = full_k_matrix.combine(k_matrix) return full_k_matrix + + def index_dependent(self, dataset_model: DatasetModel) -> bool: + return index_dependent(dataset_model) + + def calculate_matrix( + self, + dataset_model: DatasetModel, + indices: dict[str, int], + **kwargs, + ): + return calculate_matrix(self, dataset_model, indices, **kwargs) + + def finalize_data( + self, + dataset_model: DatasetModel, + dataset: xr.Dataset, + is_full_model: bool = False, + as_global: bool = False, + ): + from glotaran.builtin.megacomplexes.decay.decay_parallel_megacomplex import ( + DecayParallelMegacomplex, + ) + from glotaran.builtin.megacomplexes.decay.decay_sequential_megacomplex import ( + DecaySequentialMegacomplex, + ) + + decay_megacomplexes = [ + m + for m in dataset_model.megacomplex + if isinstance( + m, (DecayMegacomplex, DecayParallelMegacomplex, DecaySequentialMegacomplex) + ) + ] + finalize_data(decay_megacomplexes, dataset_model, dataset, is_full_model, as_global) diff --git a/glotaran/builtin/megacomplexes/decay/decay_megacomplex_base.py b/glotaran/builtin/megacomplexes/decay/decay_megacomplex_base.py deleted file mode 100644 index 3ed91e7aa..000000000 --- a/glotaran/builtin/megacomplexes/decay/decay_megacomplex_base.py +++ /dev/null @@ -1,117 +0,0 @@ -"""This package contains the decay megacomplex item.""" -from __future__ import annotations - -import numpy as np -import xarray as xr - -from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian -from glotaran.builtin.megacomplexes.decay.k_matrix import KMatrix -from glotaran.builtin.megacomplexes.decay.util import decay_matrix_implementation -from glotaran.builtin.megacomplexes.decay.util import retrieve_decay_associated_data -from glotaran.builtin.megacomplexes.decay.util import retrieve_irf -from glotaran.builtin.megacomplexes.decay.util import retrieve_species_associated_data -from glotaran.model import DatasetModel -from glotaran.model import Megacomplex - - -class DecayMegacomplexBase(Megacomplex): - """A Megacomplex with one or more K-Matrices.""" - - def get_compartments(self, dataset_model: DatasetModel) -> list[str]: - raise NotImplementedError - - def get_initial_concentration(self, dataset_model: DatasetModel) -> np.ndarray: - raise NotImplementedError - - def get_k_matrix(self) -> KMatrix: - raise NotImplementedError - - def index_dependent(self, dataset_model: DatasetModel) -> bool: - return ( - isinstance(dataset_model.irf, IrfMultiGaussian) - and dataset_model.irf.is_index_dependent() - ) - - def calculate_matrix( - self, - dataset_model: DatasetModel, - indices: dict[str, int], - **kwargs, - ): - - compartments = self.get_compartments(dataset_model) - initial_concentration = self.get_initial_concentration(dataset_model) - k_matrix = self.get_k_matrix() - - # the rates are the eigenvalues of the k matrix - rates = k_matrix.rates(compartments, initial_concentration) - - global_dimension = dataset_model.get_global_dimension() - global_index = indices.get(global_dimension) - global_axis = dataset_model.get_global_axis() - model_axis = dataset_model.get_model_axis() - - # init the matrix - size = (model_axis.size, rates.size) - matrix = np.zeros(size, dtype=np.float64) - - decay_matrix_implementation( - matrix, rates, global_index, global_axis, model_axis, dataset_model - ) - - if not np.all(np.isfinite(matrix)): - raise ValueError( - f"Non-finite concentrations for K-Matrix '{k_matrix.label}':\n" - f"{k_matrix.matrix_as_markdown(fill_parameters=True)}" - ) - - # apply A matrix - matrix = matrix @ k_matrix.a_matrix(compartments, initial_concentration) - - # done - return compartments, matrix - - def finalize_data( - self, - dataset_model: DatasetModel, - dataset: xr.Dataset, - is_full_model: bool = False, - as_global: bool = False, - ): - global_dimension = dataset_model.get_global_dimension() - name = "images" if global_dimension == "pixel" else "spectra" - decay_megacomplexes = [ - m for m in dataset_model.megacomplex if isinstance(m, DecayMegacomplexBase) - ] - - species_dimension = "decay_species" if as_global else "species" - if species_dimension not in dataset.coords: - # We are the first Decay complex called and add SAD for all decay megacomplexes - all_species = [] - for megacomplex in decay_megacomplexes: - for species in megacomplex.get_compartments(dataset_model): - if species not in all_species: - all_species.append(species) - retrieve_species_associated_data( - dataset_model, - dataset, - all_species, - species_dimension, - global_dimension, - name, - is_full_model, - as_global, - ) - if isinstance(dataset_model.irf, IrfMultiGaussian) and "irf" not in dataset: - retrieve_irf(dataset_model, dataset, global_dimension) - - if not is_full_model: - multiple_complexes = len(decay_megacomplexes) > 1 - retrieve_decay_associated_data( - self, - dataset_model, - dataset, - global_dimension, - name, - multiple_complexes, - ) diff --git a/glotaran/builtin/megacomplexes/decay/decay_parallel_megacomplex.py b/glotaran/builtin/megacomplexes/decay/decay_parallel_megacomplex.py index ee073e8b7..a64d5c0ac 100644 --- a/glotaran/builtin/megacomplexes/decay/decay_parallel_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/decay_parallel_megacomplex.py @@ -4,11 +4,15 @@ from typing import List import numpy as np +import xarray as xr -from glotaran.builtin.megacomplexes.decay.decay_megacomplex_base import DecayMegacomplexBase from glotaran.builtin.megacomplexes.decay.irf import Irf from glotaran.builtin.megacomplexes.decay.k_matrix import KMatrix +from glotaran.builtin.megacomplexes.decay.util import calculate_matrix +from glotaran.builtin.megacomplexes.decay.util import finalize_data +from glotaran.builtin.megacomplexes.decay.util import index_dependent from glotaran.model import DatasetModel +from glotaran.model import Megacomplex from glotaran.model import megacomplex from glotaran.parameter import Parameter @@ -24,7 +28,7 @@ }, register_as="decay-parallel", ) -class DecayParallelMegacomplex(DecayMegacomplexBase): +class DecayParallelMegacomplex(Megacomplex): def get_compartments(self, dataset_model: DatasetModel) -> list[str]: return self.compartments @@ -38,3 +42,35 @@ def get_k_matrix(self) -> KMatrix: (self.compartments[i], self.compartments[i]): self.rates[i] for i in range(size) } return k_matrix + + def index_dependent(self, dataset_model: DatasetModel) -> bool: + return index_dependent(dataset_model) + + def calculate_matrix( + self, + dataset_model: DatasetModel, + indices: dict[str, int], + **kwargs, + ): + return calculate_matrix(self, dataset_model, indices, **kwargs) + + def finalize_data( + self, + dataset_model: DatasetModel, + dataset: xr.Dataset, + is_full_model: bool = False, + as_global: bool = False, + ): + from glotaran.builtin.megacomplexes.decay.decay_megacomplex import DecayMegacomplex + from glotaran.builtin.megacomplexes.decay.decay_sequential_megacomplex import ( + DecaySequentialMegacomplex, + ) + + decay_megacomplexes = [ + m + for m in dataset_model.megacomplex + if isinstance( + m, (DecayMegacomplex, DecayParallelMegacomplex, DecaySequentialMegacomplex) + ) + ] + finalize_data(decay_megacomplexes, dataset_model, dataset, is_full_model, as_global) diff --git a/glotaran/builtin/megacomplexes/decay/decay_sequential_megacomplex.py b/glotaran/builtin/megacomplexes/decay/decay_sequential_megacomplex.py index ad8a87280..8824c1764 100644 --- a/glotaran/builtin/megacomplexes/decay/decay_sequential_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/decay_sequential_megacomplex.py @@ -4,11 +4,15 @@ from typing import List import numpy as np +import xarray as xr -from glotaran.builtin.megacomplexes.decay.decay_megacomplex_base import DecayMegacomplexBase from glotaran.builtin.megacomplexes.decay.irf import Irf from glotaran.builtin.megacomplexes.decay.k_matrix import KMatrix +from glotaran.builtin.megacomplexes.decay.util import calculate_matrix +from glotaran.builtin.megacomplexes.decay.util import finalize_data +from glotaran.builtin.megacomplexes.decay.util import index_dependent from glotaran.model import DatasetModel +from glotaran.model import Megacomplex from glotaran.model import megacomplex from glotaran.parameter import Parameter @@ -24,7 +28,7 @@ }, register_as="decay-sequential", ) -class DecaySequentialMegacomplex(DecayMegacomplexBase): +class DecaySequentialMegacomplex(Megacomplex): """A Megacomplex with one or more K-Matrices.""" def get_compartments(self, dataset_model: DatasetModel) -> list[str]: @@ -44,3 +48,35 @@ def get_k_matrix(self) -> KMatrix: } k_matrix.matrix[self.compartments[-1], self.compartments[-1]] = self.rates[-1] return k_matrix + + def index_dependent(self, dataset_model: DatasetModel) -> bool: + return index_dependent(dataset_model) + + def calculate_matrix( + self, + dataset_model: DatasetModel, + indices: dict[str, int], + **kwargs, + ): + return calculate_matrix(self, dataset_model, indices, **kwargs) + + def finalize_data( + self, + dataset_model: DatasetModel, + dataset: xr.Dataset, + is_full_model: bool = False, + as_global: bool = False, + ): + from glotaran.builtin.megacomplexes.decay.decay_megacomplex import DecayMegacomplex + from glotaran.builtin.megacomplexes.decay.decay_parallel_megacomplex import ( + DecayParallelMegacomplex, + ) + + decay_megacomplexes = [ + m + for m in dataset_model.megacomplex + if isinstance( + m, (DecayMegacomplex, DecayParallelMegacomplex, DecaySequentialMegacomplex) + ) + ] + finalize_data(decay_megacomplexes, dataset_model, dataset, is_full_model, as_global) diff --git a/glotaran/builtin/megacomplexes/decay/util.py b/glotaran/builtin/megacomplexes/decay/util.py index 4da323589..a400aac7e 100644 --- a/glotaran/builtin/megacomplexes/decay/util.py +++ b/glotaran/builtin/megacomplexes/decay/util.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import numba as nb import numpy as np import xarray as xr @@ -9,9 +7,96 @@ from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian from glotaran.builtin.megacomplexes.decay.irf import IrfSpectralMultiGaussian from glotaran.model import DatasetModel +from glotaran.model import Megacomplex + + +def index_dependent(dataset_model: DatasetModel) -> bool: + return ( + isinstance(dataset_model.irf, IrfMultiGaussian) and dataset_model.irf.is_index_dependent() + ) + + +def calculate_matrix( + megacomplex: Megacomplex, + dataset_model: DatasetModel, + indices: dict[str, int], + **kwargs, +): + + compartments = megacomplex.get_compartments(dataset_model) + initial_concentration = megacomplex.get_initial_concentration(dataset_model) + k_matrix = megacomplex.get_k_matrix() + + # the rates are the eigenvalues of the k matrix + rates = k_matrix.rates(compartments, initial_concentration) + + global_dimension = dataset_model.get_global_dimension() + global_index = indices.get(global_dimension) + global_axis = dataset_model.get_global_axis() + model_axis = dataset_model.get_model_axis() + + # init the matrix + size = (model_axis.size, rates.size) + matrix = np.zeros(size, dtype=np.float64) -if TYPE_CHECKING: - from glotaran.builtin.megacomplexes.decay.decay_megacomplex_base import DecayMegacomplexBase + decay_matrix_implementation( + matrix, rates, global_index, global_axis, model_axis, dataset_model + ) + + if not np.all(np.isfinite(matrix)): + raise ValueError( + f"Non-finite concentrations for K-Matrix '{k_matrix.label}':\n" + f"{k_matrix.matrix_as_markdown(fill_parameters=True)}" + ) + + # apply A matrix + matrix = matrix @ k_matrix.a_matrix(compartments, initial_concentration) + + # done + return compartments, matrix + + +def finalize_data( + decay_megacomplexes: list[Megacomplex], + dataset_model: DatasetModel, + dataset: xr.Dataset, + is_full_model: bool = False, + as_global: bool = False, +): + global_dimension = dataset_model.get_global_dimension() + name = "images" if global_dimension == "pixel" else "spectra" + + species_dimension = "decay_species" if as_global else "species" + if species_dimension not in dataset.coords: + # We are the first Decay complex called and add SAD for all decay megacomplexes + all_species = [] + for megacomplex in decay_megacomplexes: + for species in megacomplex.get_compartments(dataset_model): + if species not in all_species: + all_species.append(species) + retrieve_species_associated_data( + dataset_model, + dataset, + all_species, + species_dimension, + global_dimension, + name, + is_full_model, + as_global, + ) + if isinstance(dataset_model.irf, IrfMultiGaussian) and "irf" not in dataset: + retrieve_irf(dataset_model, dataset, global_dimension) + + if not is_full_model: + multiple_complexes = len(decay_megacomplexes) > 1 + retrieve_decay_associated_data( + megacomplex, + dataset_model, + dataset, + global_dimension, + name, + multiple_complexes, + ) def decay_matrix_implementation( @@ -153,7 +238,7 @@ def retrieve_species_associated_data( def retrieve_decay_associated_data( - megacomplex: DecayMegacomplexBase, + megacomplex: Megacomplex, dataset_model: DatasetModel, dataset: xr.Dataset, global_dimension: str,