diff --git a/glotaran/analysis/test/models.py b/glotaran/analysis/test/models.py index 004e3746e..9c56f7e58 100644 --- a/glotaran/analysis/test/models.py +++ b/glotaran/analysis/test/models.py @@ -56,13 +56,23 @@ def index_dependent(self, dataset_model): class SimpleTestModel(Model): @classmethod - def from_dict(cls, model_dict): + def from_dict( + cls, + model_dict, + *, + megacomplex_types: dict[str, type[Megacomplex]] | None = None, + default_megacomplex_type: str | None = None, + ): + defaults: dict[str, type[Megacomplex]] = { + "model_complex": SimpleTestMegacomplex, + "global_complex": SimpleTestMegacomplexGlobal, + } + if megacomplex_types is not None: + defaults.update(megacomplex_types) return super().from_dict( model_dict, - megacomplex_types={ - "model_complex": SimpleTestMegacomplex, - "global_complex": SimpleTestMegacomplexGlobal, - }, + megacomplex_types=defaults, + default_megacomplex_type=default_megacomplex_type, ) @@ -157,14 +167,24 @@ def finalize_data( class DecayModel(Model): @classmethod - def from_dict(cls, model_dict): + def from_dict( + cls, + model_dict, + *, + megacomplex_types: dict[str, type[Megacomplex]] | None = None, + default_megacomplex_type: str | None = None, + ): + defaults: dict[str, type[Megacomplex]] = { + "model_complex": SimpleKineticMegacomplex, + "global_complex": SimpleSpectralMegacomplex, + "global_complex_shaped": ShapedSpectralMegacomplex, + } + if megacomplex_types is not None: + defaults.update(megacomplex_types) return super().from_dict( model_dict, - megacomplex_types={ - "model_complex": SimpleKineticMegacomplex, - "global_complex": SimpleSpectralMegacomplex, - "global_complex_shaped": ShapedSpectralMegacomplex, - }, + megacomplex_types=defaults, + default_megacomplex_type=default_megacomplex_type, ) diff --git a/glotaran/analysis/util.py b/glotaran/analysis/util.py index 25b44c3b2..d54cb06b0 100644 --- a/glotaran/analysis/util.py +++ b/glotaran/analysis/util.py @@ -71,15 +71,7 @@ def calculate_matrix( clp_labels = this_clp_labels matrix = this_matrix else: - tmp_clp_labels = clp_labels + [c for c in this_clp_labels if c not in clp_labels] - tmp_matrix = np.zeros((matrix.shape[0], len(tmp_clp_labels)), dtype=np.float64) - for idx, label in enumerate(tmp_clp_labels): - if label in clp_labels: - tmp_matrix[:, idx] += matrix[:, clp_labels.index(label)] - if label in this_clp_labels: - tmp_matrix[:, idx] += this_matrix[:, this_clp_labels.index(label)] - clp_labels = tmp_clp_labels - matrix = tmp_matrix + clp_labels, matrix = combine_matrix(matrix, this_matrix, clp_labels, this_clp_labels) if as_global_model: dataset_model.swap_dimensions() @@ -87,6 +79,17 @@ def calculate_matrix( return CalculatedMatrix(clp_labels, matrix) +def combine_matrix(matrix, this_matrix, clp_labels, this_clp_labels): + tmp_clp_labels = clp_labels + [c for c in this_clp_labels if c not in clp_labels] + tmp_matrix = np.zeros((matrix.shape[0], len(tmp_clp_labels)), dtype=np.float64) + for idx, label in enumerate(tmp_clp_labels): + if label in clp_labels: + tmp_matrix[:, idx] += matrix[:, clp_labels.index(label)] + if label in this_clp_labels: + tmp_matrix[:, idx] += this_matrix[:, this_clp_labels.index(label)] + return tmp_clp_labels, tmp_matrix + + @nb.jit(nopython=True, parallel=True) def apply_weight(matrix, weight): for i in nb.prange(matrix.shape[1]): diff --git a/glotaran/builtin/megacomplexes/damped_oscillation/__init__.py b/glotaran/builtin/megacomplexes/damped_oscillation/__init__.py new file mode 100755 index 000000000..2f975d246 --- /dev/null +++ b/glotaran/builtin/megacomplexes/damped_oscillation/__init__.py @@ -0,0 +1,3 @@ +from glotaran.builtin.megacomplexes.damped_oscillation.damped_oscillation_megacomplex import ( + DampedOscillationMegacomplex, +) diff --git a/glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py b/glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py new file mode 100644 index 000000000..d57996a6f --- /dev/null +++ b/glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +from typing import List + +import numba as nb +import numpy as np +import xarray as xr +from scipy.special import erf + +from glotaran.builtin.megacomplexes.decay.irf import Irf +from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian +from glotaran.model import DatasetModel +from glotaran.model import Megacomplex +from glotaran.model import Model +from glotaran.model import megacomplex +from glotaran.model.item import model_item_validator +from glotaran.parameter import Parameter + + +@megacomplex( + dimension="time", + dataset_model_items={ + "irf": {"type": Irf, "allow_none": True}, + }, + properties={ + "labels": List[str], + "frequencies": List[Parameter], + "rates": List[Parameter], + }, + register_as="damped-oscillation", +) +class DampedOscillationMegacomplex(Megacomplex): + @model_item_validator(False) + def ensure_oscillation_paramater(self, model: Model) -> list[str]: + + problems = [] + + if len(self.labels) != len(self.frequencies) or len(self.labels) != len(self.rates): + problems.append( + f"Size of labels ({len(self.labels)}), frequencies ({len(self.frequencies)}) " + f"and rates ({len(self.rates)}) does not match for damped oscillation " + f"megacomplex '{self.label}'." + ) + + return problems + + def calculate_matrix( + self, + dataset_model: DatasetModel, + indices: dict[str, int], + **kwargs, + ): + + clp_label = [f"{label}_cos" for label in self.labels] + [ + f"{label}_sin" for label in self.labels + ] + + model_axis = dataset_model.get_model_axis() + delta = np.abs(model_axis[1:] - model_axis[:-1]) + delta_min = delta[np.argmin(delta)] + frequency_max = 1 / (2 * 0.03 * delta_min) + frequencies = np.array(self.frequencies) * 0.03 * 2 * np.pi + frequencies[frequencies >= frequency_max] = np.mod( + frequencies[frequencies >= frequency_max], frequency_max + ) + rates = np.array(self.rates) + + matrix = np.ones((model_axis.size, len(clp_label)), dtype=np.float64) + + if dataset_model.irf is None: + calculate_damped_oscillation_matrix_no_irf(matrix, frequencies, rates, model_axis) + elif isinstance(dataset_model.irf, IrfMultiGaussian): + global_dimension = dataset_model.get_global_dimension() + global_axis = dataset_model.get_global_axis() + global_index = indices.get(global_dimension) + centers, widths, scales, shift, _, _ = dataset_model.irf.parameter( + global_index, global_axis + ) + for center, width, scale in zip(centers, widths, scales): + matrix += calculate_damped_oscillation_matrix_gaussian_irf( + frequencies, + rates, + model_axis, + center, + width, + shift, + scale, + ) + matrix /= np.sum(scales) + + return clp_label, matrix + + def index_dependent(self, dataset_model: DatasetModel) -> bool: + return ( + isinstance(dataset_model.irf, IrfMultiGaussian) + and dataset_model.irf.is_index_dependent() + ) + + def finalize_data( + self, + dataset_model: DatasetModel, + dataset: xr.Dataset, + is_full_model: bool = False, + as_global: bool = False, + ): + if is_full_model: + return + + megacomplexes = ( + dataset_model.global_megacomplex if is_full_model else dataset_model.megacomplex + ) + unique = len([m for m in megacomplexes if isinstance(m, DampedOscillationMegacomplex)]) < 2 + + prefix = "damped_oscillation" if unique else f"{self.label}_damped_oscillation" + + dataset.coords[f"{prefix}"] = self.labels + dataset.coords[f"{prefix}_frequency"] = (prefix, self.frequencies) + dataset.coords[f"{prefix}_rate"] = (prefix, self.rates) + + dim1 = dataset_model.get_global_axis().size + dim2 = len(self.labels) + doas = np.zeros((dim1, dim2), dtype=np.float64) + phase = np.zeros((dim1, dim2), dtype=np.float64) + for i, label in enumerate(self.labels): + sin = dataset.clp.sel(clp_label=f"{label}_sin") + cos = dataset.clp.sel(clp_label=f"{label}_cos") + doas[:, i] = np.sqrt(sin * sin + cos * cos) + phase[:, i] = np.unwrap(np.arctan2(sin, cos)) + + dataset[f"{prefix}_associated_spectra"] = ( + (dataset_model.get_global_dimension(), prefix), + doas, + ) + + dataset[f"{prefix}_phase"] = ( + (dataset_model.get_global_dimension(), prefix), + phase, + ) + + if not is_full_model: + if self.index_dependent(dataset_model): + dataset[f"{prefix}_sin"] = ( + ( + dataset_model.get_global_dimension(), + dataset_model.get_model_dimension(), + prefix, + ), + dataset.matrix.sel(clp_label=[f"{label}_sin" for label in self.labels]).values, + ) + + dataset[f"{prefix}_cos"] = ( + ( + dataset_model.get_global_dimension(), + dataset_model.get_model_dimension(), + prefix, + ), + dataset.matrix.sel(clp_label=[f"{label}_cos" for label in self.labels]).values, + ) + else: + dataset[f"{prefix}_sin"] = ( + (dataset_model.get_model_dimension(), prefix), + dataset.matrix.sel(clp_label=[f"{label}_sin" for label in self.labels]).values, + ) + + dataset[f"{prefix}_cos"] = ( + (dataset_model.get_model_dimension(), prefix), + dataset.matrix.sel(clp_label=[f"{label}_cos" for label in self.labels]).values, + ) + + +@nb.jit(nopython=True, parallel=True) +def calculate_damped_oscillation_matrix_no_irf(matrix, frequencies, rates, axis): + + idx = 0 + for frequency, rate in zip(frequencies, rates): + osc = np.exp(-rate * axis - 1j * frequency * axis) + matrix[:, idx] = osc.real + matrix[:, idx + 1] = osc.imag + idx += 2 + + +def calculate_damped_oscillation_matrix_gaussian_irf( + frequencies: np.ndarray, + rates: np.ndarray, + model_axis: np.ndarray, + center: float, + width: float, + shift: float, + scale: float, +): + shifted_axis = model_axis - center - shift + d = width ** 2 + k = rates + 1j * frequencies + dk = k * d + sqwidth = np.sqrt(2) * width + a = (-1 * shifted_axis[:, None] + 0.5 * dk) * k + a = np.minimum(a, 709) + a = np.exp(a) + b = 1 + erf((shifted_axis[:, None] - dk) / sqwidth) + osc = a * b * scale + return np.concatenate((osc.real, osc.imag), axis=1) diff --git a/glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py b/glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py new file mode 100755 index 000000000..fe1c29927 --- /dev/null +++ b/glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py @@ -0,0 +1,409 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from glotaran.analysis.optimize import optimize +from glotaran.analysis.simulation import simulate +from glotaran.builtin.megacomplexes.damped_oscillation import DampedOscillationMegacomplex +from glotaran.builtin.megacomplexes.decay import DecayMegacomplex +from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex +from glotaran.model import Megacomplex +from glotaran.model import Model +from glotaran.parameter import ParameterGroup +from glotaran.project import Scheme + + +class DampedOscillationsModel(Model): + @classmethod + def from_dict( + cls, + model_dict, + *, + megacomplex_types: dict[str, type[Megacomplex]] | None = None, + default_megacomplex_type: str | None = None, + ): + defaults: dict[str, type[Megacomplex]] = { + "damped_oscillation": DampedOscillationMegacomplex, + "decay": DecayMegacomplex, + "spectral": SpectralMegacomplex, + } + if megacomplex_types is not None: + defaults.update(megacomplex_types) + return super().from_dict( + model_dict, + megacomplex_types=defaults, + default_megacomplex_type=default_megacomplex_type, + ) + + +class OneOscillation: + sim_model = DampedOscillationsModel.from_dict( + { + "megacomplex": { + "m1": { + "type": "damped_oscillation", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + "m2": { + "type": "spectral", + "shape": {"osc1_cos": "sh1", "osc1_sin": "sh1"}, + }, + }, + "shape": { + "sh1": { + "type": "gaussian", + "amplitude": "shapes.amps.1", + "location": "shapes.locs.1", + "width": "shapes.width.1", + }, + }, + "dataset": {"dataset1": {"megacomplex": ["m1"], "global_megacomplex": ["m2"]}}, + } + ) + + model = DampedOscillationsModel.from_dict( + { + "megacomplex": { + "m1": { + "type": "damped_oscillation", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + }, + "dataset": {"dataset1": {"megacomplex": ["m1"]}}, + } + ) + + wanted_parameter = ParameterGroup.from_dict( + { + "osc": [ + ["freq", 25.5], + ["rate", 0.1], + ], + "shapes": {"amps": [7], "locs": [5], "width": [4]}, + } + ) + + parameter = ParameterGroup.from_dict( + { + "osc": [ + ["freq", 20], + ["rate", 0.3], + ], + } + ) + + time = np.arange(0, 3, 0.01) + spectral = np.arange(0, 10) + axis = {"time": time, "spectral": spectral} + + wanted_clp = ["osc1_cos", "osc1_sin"] + wanted_shape = (300, 2) + + +class OneOscillationWithIrf: + sim_model = DampedOscillationsModel.from_dict( + { + "megacomplex": { + "m1": { + "type": "damped_oscillation", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + "m2": { + "type": "spectral", + "shape": {"osc1_cos": "sh1", "osc1_sin": "sh1"}, + }, + }, + "shape": { + "sh1": { + "type": "gaussian", + "amplitude": "shapes.amps.1", + "location": "shapes.locs.1", + "width": "shapes.width.1", + }, + }, + "irf": { + "irf1": { + "type": "gaussian", + "center": "irf.center", + "width": "irf.width", + }, + }, + "dataset": { + "dataset1": { + "megacomplex": ["m1"], + "global_megacomplex": ["m2"], + "irf": "irf1", + } + }, + } + ) + + model = DampedOscillationsModel.from_dict( + { + "megacomplex": { + "m1": { + "type": "damped_oscillation", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + }, + "irf": { + "irf1": { + "type": "gaussian", + "center": "irf.center", + "width": "irf.width", + }, + }, + "dataset": { + "dataset1": { + "megacomplex": ["m1"], + "irf": "irf1", + } + }, + } + ) + + wanted_parameter = ParameterGroup.from_dict( + { + "osc": [ + ["freq", 25], + ["rate", 0.1], + ], + "shapes": {"amps": [7], "locs": [5], "width": [4]}, + "irf": [["center", 0.3], ["width", 0.1]], + } + ) + + parameter = ParameterGroup.from_dict( + { + "osc": [ + ["freq", 25], + ["rate", 0.1], + ], + "irf": [["center", 0.3], ["width", 0.1]], + } + ) + + time = np.arange(0, 3, 0.01) + spectral = np.arange(0, 10) + axis = {"time": time, "spectral": spectral} + + wanted_clp = ["osc1_cos", "osc1_sin"] + wanted_shape = (300, 2) + + +class OneOscillationWithSequentialModel: + sim_model = DampedOscillationsModel.from_dict( + { + "initial_concentration": { + "j1": {"compartments": ["s1", "s2"], "parameters": ["j.1", "j.0"]}, + }, + "k_matrix": { + "k1": { + "matrix": { + ("s2", "s1"): "kinetic.1", + ("s2", "s2"): "kinetic.2", + } + } + }, + "megacomplex": { + "m1": {"type": "decay", "k_matrix": ["k1"]}, + "m2": { + "type": "damped_oscillation", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + "m3": { + "type": "spectral", + "shape": { + "osc1_cos": "sh1", + "osc1_sin": "sh1", + "s1": "sh2", + "s2": "sh3", + }, + }, + }, + "shape": { + "sh1": { + "type": "gaussian", + "amplitude": "shapes.amps.1", + "location": "shapes.locs.1", + "width": "shapes.width.1", + }, + "sh2": { + "type": "gaussian", + "amplitude": "shapes.amps.2", + "location": "shapes.locs.2", + "width": "shapes.width.2", + }, + "sh3": { + "type": "gaussian", + "amplitude": "shapes.amps.3", + "location": "shapes.locs.3", + "width": "shapes.width.3", + }, + }, + "irf": { + "irf1": { + "type": "gaussian", + "center": "irf.center", + "width": "irf.width", + }, + }, + "dataset": { + "dataset1": { + "initial_concentration": "j1", + "irf": "irf1", + "megacomplex": ["m1", "m2"], + "global_megacomplex": ["m3"], + } + }, + } + ) + + model = DampedOscillationsModel.from_dict( + { + "initial_concentration": { + "j1": {"compartments": ["s1", "s2"], "parameters": ["j.1", "j.0"]}, + }, + "k_matrix": { + "k1": { + "matrix": { + ("s2", "s1"): "kinetic.1", + ("s2", "s2"): "kinetic.2", + } + } + }, + "megacomplex": { + "m1": {"type": "decay", "k_matrix": ["k1"]}, + "m2": { + "type": "damped_oscillation", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + }, + "irf": { + "irf1": { + "type": "gaussian", + "center": "irf.center", + "width": "irf.width", + }, + }, + "dataset": { + "dataset1": { + "initial_concentration": "j1", + "irf": "irf1", + "megacomplex": ["m1", "m2"], + } + }, + } + ) + + wanted_parameter = ParameterGroup.from_dict( + { + "j": [ + ["1", 1, {"vary": False, "non-negative": False}], + ["0", 0, {"vary": False, "non-negative": False}], + ], + "kinetic": [ + ["1", 0.2], + ["2", 0.01], + ], + "osc": [ + ["freq", 25], + ["rate", 0.1], + ], + "shapes": {"amps": [0.07, 2, 4], "locs": [5, 2, 8], "width": [4, 2, 3]}, + "irf": [["center", 0.3], ["width", 0.1]], + } + ) + + parameter = ParameterGroup.from_dict( + { + "j": [ + ["1", 1, {"vary": False, "non-negative": False}], + ["0", 0, {"vary": False, "non-negative": False}], + ], + "kinetic": [ + ["1", 0.2], + ["2", 0.01], + ], + "osc": [ + ["freq", 25], + ["rate", 0.1], + ], + "irf": [["center", 0.3], ["width", 0.1]], + } + ) + + time = np.arange(-1, 5, 0.01) + spectral = np.arange(0, 10) + axis = {"time": time, "spectral": spectral} + + wanted_clp = ["osc1_cos", "osc1_sin", "s1", "s2"] + wanted_shape = (600, 4) + + +@pytest.mark.parametrize( + "suite", + [ + OneOscillation, + OneOscillationWithIrf, + OneOscillationWithSequentialModel, + ], +) +def test_doas_model(suite): + + print(suite.sim_model.validate()) # noqa + assert suite.sim_model.valid() + + print(suite.model.validate()) # noqa + assert suite.model.valid() + + print(suite.sim_model.validate(suite.wanted_parameter)) # noqa + assert suite.sim_model.valid(suite.wanted_parameter) + + print(suite.model.validate(suite.parameter)) # noqa + assert suite.model.valid(suite.parameter) + + dataset = simulate(suite.sim_model, "dataset1", suite.wanted_parameter, suite.axis) + print(dataset) # noqa + + assert dataset.data.shape == (suite.axis["time"].size, suite.axis["spectral"].size) + + print(suite.parameter) # noqa + print(suite.wanted_parameter) # noqa + + data = {"dataset1": dataset} + scheme = Scheme( + model=suite.model, + parameters=suite.parameter, + data=data, + maximum_number_function_evaluations=20, + ) + result = optimize(scheme, raise_exception=True) + print(result.optimized_parameters) # noqa + + for label, param in result.optimized_parameters.all(): + assert np.allclose(param.value, suite.wanted_parameter.get(label).value, rtol=1e-1) + + resultdata = result.data["dataset1"] + assert np.array_equal(dataset["time"], resultdata["time"]) + assert np.array_equal(dataset["spectral"], resultdata["spectral"]) + assert dataset.data.shape == resultdata.fitted_data.shape + assert np.allclose(dataset.data, resultdata.fitted_data) + + assert "damped_oscillation_cos" in resultdata + assert "damped_oscillation_sin" in resultdata + assert "damped_oscillation_associated_spectra" in resultdata + assert "damped_oscillation_phase" in resultdata diff --git a/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py b/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py index e57f31f5e..cbc55b165 100644 --- a/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py @@ -9,7 +9,6 @@ from glotaran.builtin.megacomplexes.decay.initial_concentration import InitialConcentration from glotaran.builtin.megacomplexes.decay.irf import Irf from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian -from glotaran.builtin.megacomplexes.decay.irf import IrfSpectralMultiGaussian from glotaran.builtin.megacomplexes.decay.k_matrix import KMatrix from glotaran.builtin.megacomplexes.decay.util import decay_matrix_implementation from glotaran.builtin.megacomplexes.decay.util import retrieve_decay_associated_data @@ -57,10 +56,8 @@ def involved_compartments(self): def index_dependent(self, dataset_model: DatasetModel) -> bool: return ( - isinstance(dataset_model.irf, IrfSpectralMultiGaussian) - and dataset_model.irf.dispersion_center is not None - ) or ( - isinstance(dataset_model.irf, IrfMultiGaussian) and dataset_model.irf.shift is not None + isinstance(dataset_model.irf, IrfMultiGaussian) + and dataset_model.irf.is_index_dependent() ) def calculate_matrix( diff --git a/glotaran/builtin/megacomplexes/decay/irf.py b/glotaran/builtin/megacomplexes/decay/irf.py index e1297c74f..7ddbc17f6 100644 --- a/glotaran/builtin/megacomplexes/decay/irf.py +++ b/glotaran/builtin/megacomplexes/decay/irf.py @@ -106,6 +106,9 @@ def calculate(self, index: int, global_axis: np.ndarray, model_axis: np.ndarray) for center, width, scale in zip(centers, widths, scales) ) + def is_index_dependent(self): + return self.shift is not None + @model_item( properties={ @@ -191,6 +194,9 @@ def calculate_dispersion(self, axis): dispersion.append(center) return np.asarray(dispersion).T + def is_index_dependent(self): + return super().is_index_dependent() or self.dispersion_center is not None + @model_item( properties={ diff --git a/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py index 32e266128..f788a6b90 100644 --- a/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import pytest import xarray as xr @@ -5,6 +7,7 @@ from glotaran.analysis.optimize import optimize from glotaran.analysis.simulation import simulate from glotaran.builtin.megacomplexes.decay import DecayMegacomplex +from glotaran.model import Megacomplex from glotaran.model import Model from glotaran.parameter import ParameterGroup from glotaran.project import Scheme @@ -22,12 +25,22 @@ def _create_gaussian_clp(labels, amplitudes, centers, widths, axis): class DecayModel(Model): @classmethod - def from_dict(cls, model_dict): + def from_dict( + cls, + model_dict, + *, + megacomplex_types: dict[str, type[Megacomplex]] | None = None, + default_megacomplex_type: str | None = None, + ): + defaults: dict[str, type[Megacomplex]] = { + "decay": DecayMegacomplex, + } + if megacomplex_types is not None: + defaults.update(megacomplex_types) return super().from_dict( model_dict, - megacomplex_types={ - "decay": DecayMegacomplex, - }, + megacomplex_types=defaults, + default_megacomplex_type=default_megacomplex_type, ) diff --git a/glotaran/builtin/megacomplexes/decay/util.py b/glotaran/builtin/megacomplexes/decay/util.py index b984fe4cd..c51a3ce67 100644 --- a/glotaran/builtin/megacomplexes/decay/util.py +++ b/glotaran/builtin/megacomplexes/decay/util.py @@ -173,7 +173,11 @@ def retrieve_decay_associated_data( das = dataset[f"species_associated_{name}"].sel(species=species).values @ a_matrix.T - component_coords = {"rate": ("component", rates), "lifetime": ("component", lifetimes)} + component_coords = { + "component": np.arange(rates.size), + "rate": ("component", rates), + "lifetime": ("component", lifetimes), + } das_coords = component_coords.copy() das_coords[global_dimension] = dataset.coords[global_dimension] das_name = f"decay_associated_{name}" diff --git a/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py index ce559a31d..4b5659943 100644 --- a/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py +++ b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import pytest import xarray as xr @@ -7,6 +9,7 @@ from glotaran.analysis.util import calculate_matrix from glotaran.builtin.megacomplexes.decay.test.test_decay_megacomplex import DecayModel from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex +from glotaran.model import Megacomplex from glotaran.model import Model from glotaran.parameter import ParameterGroup from glotaran.project import Scheme @@ -14,12 +17,22 @@ class SpectralModel(Model): @classmethod - def from_dict(cls, model_dict): + def from_dict( + cls, + model_dict, + *, + megacomplex_types: dict[str, type[Megacomplex]] | None = None, + default_megacomplex_type: str | None = None, + ): + defaults: dict[str, type[Megacomplex]] = { + "spectral": SpectralMegacomplex, + } + if megacomplex_types is not None: + defaults.update(megacomplex_types) return super().from_dict( model_dict, - megacomplex_types={ - "spectral": SpectralMegacomplex, - }, + megacomplex_types=defaults, + default_megacomplex_type=default_megacomplex_type, ) diff --git a/glotaran/model/model.py b/glotaran/model/model.py index 60ec800b0..13c7cf24c 100644 --- a/glotaran/model/model.py +++ b/glotaran/model/model.py @@ -56,7 +56,7 @@ def __init__( @classmethod def from_dict( cls, - model_dict_ref: dict, + model_dict: dict, *, megacomplex_types: dict[str, type[Megacomplex]], default_megacomplex_type: str | None = None, @@ -73,10 +73,10 @@ def from_dict( megacomplex_types=megacomplex_types, default_megacomplex_type=default_megacomplex_type ) - model_dict = copy.deepcopy(model_dict_ref) + model_dict_local = copy.deepcopy(model_dict) # TODO: maybe redundant? # iterate over items - for name, items in list(model_dict.items()): + for name, items in list(model_dict_local.items()): if name not in model._model_items: warn(f"Unknown model item type '{name}'.") diff --git a/glotaran/test/test_spectral_penalties.py b/glotaran/test/test_spectral_penalties.py index 5b055aedc..a9a870555 100644 --- a/glotaran/test/test_spectral_penalties.py +++ b/glotaran/test/test_spectral_penalties.py @@ -1,6 +1,5 @@ -# To add a new cell, type '# %%' -# To add a new markdown cell, type '# %% [markdown]' -# %% +from __future__ import annotations + import importlib from collections import namedtuple from copy import deepcopy @@ -12,6 +11,7 @@ from glotaran.builtin.megacomplexes.decay import DecayMegacomplex from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex from glotaran.io import prepare_time_trace_dataset +from glotaran.model import Megacomplex from glotaran.model import Model from glotaran.parameter import ParameterGroup from glotaran.project import Scheme @@ -28,13 +28,23 @@ class SpectralDecayModel(Model): @classmethod - def from_dict(cls, model_dict): + def from_dict( + cls, + model_dict, + *, + megacomplex_types: dict[str, type[Megacomplex]] | None = None, + default_megacomplex_type: str | None = None, + ): + defaults: dict[str, type[Megacomplex]] = { + "decay": DecayMegacomplex, + "spectral": SpectralMegacomplex, + } + if megacomplex_types is not None: + defaults.update(megacomplex_types) return super().from_dict( model_dict, - megacomplex_types={ - "decay": DecayMegacomplex, - "spectral": SpectralMegacomplex, - }, + megacomplex_types=defaults, + default_megacomplex_type=default_megacomplex_type, ) diff --git a/setup.cfg b/setup.cfg index 9b3a44fbc..2d7b1e90b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -61,6 +61,7 @@ glotaran.plugins.data_io = glotaran.plugins.megacomplexes = baseline = glotaran.builtin.megacomplexes.baseline coherent_artifact = glotaran.builtin.megacomplexes.coherent_artifact + damped_oscillation = glotaran.builtin.megacomplexes.damped_oscillation decay = glotaran.builtin.megacomplexes.decay spectral = glotaran.builtin.megacomplexes.spectral glotaran.plugins.project_io =