From de4b949c0ffb42b5ef382f1378b0d44ae95384b4 Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Sun, 27 Jun 2021 01:18:23 +0200 Subject: [PATCH 01/29] =?UTF-8?q?=F0=9F=9A=87=20Raised=20benchmark=20timeo?= =?UTF-8?q?ut=20to=205min=20(#724)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../benchmarks/integration/ex_two_datasets/benchmark.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/benchmark/benchmarks/integration/ex_two_datasets/benchmark.py b/benchmark/benchmarks/integration/ex_two_datasets/benchmark.py index f399f6bfc..e6279b1f8 100644 --- a/benchmark/benchmarks/integration/ex_two_datasets/benchmark.py +++ b/benchmark/benchmarks/integration/ex_two_datasets/benchmark.py @@ -71,3 +71,9 @@ def peakmem_create_result(self): _create_result( self.problem, self.ls_result, self.free_parameter_labels, self.termination_reason ) + + +if __name__ == "__main__": + test = IntegrationTwoDatasets() + test.setup() + test.time_optimize() From ba530299f42eafd3b15c9322b29568d414efb701 Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Sun, 4 Jul 2021 02:20:05 +0200 Subject: [PATCH 02/29] =?UTF-8?q?=E2=9C=A8=20Use=20xarray=20internally=20a?= =?UTF-8?q?nd=20move=20relations/constraints/penaltys=20to=20glotaran.mode?= =?UTF-8?q?l=20(#734)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Changed calculate_matrix function to return xarrays * Added relation and constraint to basemodel * Added functions to apply relations and constraints * Removed LabelAndMatrix class * Added relations and constraints tests to base model * Added penalties to base model * Adapted models to changes * Address xarray deprecation warnings Fixed DeprecationWarning: Using a DataArray object to construct a variable is ambiguous, please extract the data using the .data property. This will raise a TypeError in 0.19.0. 
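For reference, a minimal sketch of the deprecation fix described above, mirroring the `.data` change made to glotaran/analysis/simulation.py in this patch (the variable names here are illustrative, not taken from the glotaran code):

    import numpy as np
    import xarray as xr

    model_axis = xr.DataArray(np.arange(0, 50, 1.5), dims="time")

    # Deprecated: passing a DataArray as the data part of a (dim, data) tuple makes
    # xarray warn "Using a DataArray object to construct a variable is ambiguous ..."
    # result = xr.DataArray(data=0.0, coords=[("time", model_axis)])

    # Fixed: extract the underlying numpy array with .data before building the variable.
    result = xr.DataArray(data=0.0, coords=[("time", model_axis.data)])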
Co-authored-by: Jörn Weißenborn Co-authored-by: Joris Snellenburg --- glotaran/analysis/problem.py | 53 +-- glotaran/analysis/problem_grouped.py | 364 ++++++++---------- glotaran/analysis/problem_ungrouped.py | 160 ++++---- glotaran/analysis/simulation.py | 29 +- glotaran/analysis/test/models.py | 28 +- glotaran/analysis/test/test_constraints.py | 44 +++ glotaran/analysis/test/test_optimization.py | 23 +- glotaran/analysis/test/test_penalties.py | 53 +++ glotaran/analysis/test/test_problem.py | 33 +- glotaran/analysis/test/test_relations.py | 46 +++ glotaran/analysis/util.py | 212 ++++++---- .../kinetic_baseline_megacomplex.py | 12 +- .../kinetic_decay_megacomplex.py | 11 +- .../kinetic_image/test/test_baseline.py | 17 +- .../coherent_artifact_megacomplex.py | 12 +- .../test/test_coherent_artifact.py | 11 +- .../test/test_spectral_constraints.py | 113 ------ .../test/test_spectral_penalties.py | 35 +- .../test/test_spectral_relations.py | 138 ------- .../models/spectral/spectral_megacomplex.py | 14 +- .../spectral/test/test_spectral_model.py | 30 +- glotaran/model/__init__.py | 5 + glotaran/model/clp_penalties.py | 161 ++++++++ glotaran/model/constraint.py | 66 ++++ glotaran/model/dataset_descriptor.py | 11 + glotaran/model/decorator.py | 6 + glotaran/model/interval_property.py | 44 +++ glotaran/model/megacomplex.py | 8 +- glotaran/model/relation.py | 21 + tox.ini | 4 + 30 files changed, 977 insertions(+), 787 deletions(-) create mode 100644 glotaran/analysis/test/test_constraints.py create mode 100644 glotaran/analysis/test/test_penalties.py create mode 100644 glotaran/analysis/test/test_relations.py delete mode 100644 glotaran/builtin/models/kinetic_spectrum/test/test_spectral_constraints.py delete mode 100644 glotaran/builtin/models/kinetic_spectrum/test/test_spectral_relations.py create mode 100644 glotaran/model/clp_penalties.py create mode 100644 glotaran/model/constraint.py create mode 100644 glotaran/model/interval_property.py create mode 100644 glotaran/model/relation.py diff --git a/glotaran/analysis/problem.py b/glotaran/analysis/problem.py index 8243f3baa..71cfe8f2f 100644 --- a/glotaran/analysis/problem.py +++ b/glotaran/analysis/problem.py @@ -92,9 +92,7 @@ def __init__(self, scheme: Scheme): # all of the above are always not None - self._clp_labels = None self._matrices = None - self._reduced_clp_labels = None self._reduced_matrices = None self._reduced_clps = None self._clps = None @@ -166,14 +164,6 @@ def groups(self) -> dict[str, list[str]]: self.init_bag() return self._groups - @property - def clp_labels( - self, - ) -> dict[str, list[str] | list[list[str]]]: - if self._clp_labels is None: - self.calculate_matrices() - return self._clp_labels - @property def matrices( self, @@ -182,14 +172,6 @@ def matrices( self.calculate_matrices() return self._matrices - @property - def reduced_clp_labels( - self, - ) -> dict[str, list[str] | list[list[str]]]: - if self._reduced_clp_labels is None: - self.calculate_matrices() - return self._reduced_clp_labels - @property def reduced_matrices( self, @@ -235,23 +217,12 @@ def additional_penalty( self, ) -> dict[str, list[float]]: if self._additional_penalty is None: - self.calculate_additional_penalty() + self.calculate_residual() return self._additional_penalty @property def full_penalty(self) -> np.ndarray: - if self._full_penalty is None: - residuals = self.weighted_residuals - additional_penalty = self.additional_penalty - if not self.grouped: - residuals = [np.concatenate(residuals[label]) for label in residuals.keys()] - - 
self._full_penalty = ( - np.concatenate((np.concatenate(residuals), additional_penalty)) - if additional_penalty is not None - else np.concatenate(residuals) - ) - return self._full_penalty + raise NotImplementedError @property def cost(self) -> float: @@ -272,9 +243,7 @@ def reset(self): self._reset_results() def _reset_results(self): - self._clp_labels = None self._matrices = None - self._reduced_clp_labels = None self._reduced_matrices = None self._reduced_clps = None self._clps = None @@ -372,24 +341,6 @@ def calculate_matrices(self): def calculate_residual(self): raise NotImplementedError - def calculate_additional_penalty(self) -> np.ndarray | dict[str, np.ndarray]: - """Calculates additional penalties by calling the model.additional_penalty function.""" - if ( - callable(self.model.has_additional_penalty_function) - and self.model.has_additional_penalty_function() - ): - self._additional_penalty = self.model.additional_penalty_function( - self.parameters, - self.clp_labels, - self.clps, - self.matrices, - self.data, - self._scheme.group_tolerance, - ) - else: - self._additional_penalty = None - return self._additional_penalty - def create_result_data( self, copy: bool = True, history_index: int | None = None ) -> dict[str, xr.Dataset]: diff --git a/glotaran/analysis/problem_grouped.py b/glotaran/analysis/problem_grouped.py index b4a180a71..1708f0841 100644 --- a/glotaran/analysis/problem_grouped.py +++ b/glotaran/analysis/problem_grouped.py @@ -11,12 +11,12 @@ from glotaran.analysis.problem import ParameterError from glotaran.analysis.problem import Problem from glotaran.analysis.problem import ProblemGroup -from glotaran.analysis.util import LabelAndMatrix +from glotaran.analysis.util import calculate_clp_penalties from glotaran.analysis.util import calculate_matrix -from glotaran.analysis.util import combine_matrices from glotaran.analysis.util import find_closest_index from glotaran.analysis.util import find_overlap from glotaran.analysis.util import reduce_matrix +from glotaran.analysis.util import retrieve_clps from glotaran.model import DatasetDescriptor from glotaran.project import Scheme @@ -53,6 +53,7 @@ def __init__(self, scheme: Scheme): ) self._global_dimension = global_dimensions.pop() self._model_dimension = model_dimensions.pop() + self._group_clp_labels = None def init_bag(self): """Initializes a grouped problem bag.""" @@ -190,229 +191,213 @@ def calculate_matrices(self): def calculate_index_dependent_matrices( self, - ) -> tuple[ - dict[str, list[list[str]]], - dict[str, list[np.ndarray]], - list[list[str]], - list[np.ndarray], - ]: + ) -> tuple[dict[str, list[np.ndarray]], list[np.ndarray],]: """Calculates the index dependent model matrices.""" def calculate_group( group: ProblemGroup, descriptors: dict[str, DatasetDescriptor] - ) -> tuple[list[tuple[LabelAndMatrix, str]], float]: - result = [ - ( - calculate_matrix( - self._model, - descriptors[problem.label], - problem.indices, - problem.axis, - ), - problem.label, + ) -> tuple[list[xr.DataArray], xr.DataArray, xr.DataArray]: + matrices = [ + calculate_matrix( + descriptors[problem.label], + problem.indices, ) for problem in group.descriptor ] global_index = group.descriptor[0].indices[self._global_dimension] global_index = group.descriptor[0].axis[self._global_dimension][global_index] - return result, global_index - - def reduce_and_combine_matrices( - results: tuple[list[tuple[LabelAndMatrix, str]], float], - ) -> LabelAndMatrix: - index_results, index = results - constraint_labels_and_matrices = list( - map( 
- lambda result: reduce_matrix( - self._model, result[1], self.parameters, result[0], index - ), - index_results, - ) + combined_matrix = xr.concat(matrices, dim=self._model_dimension).fillna(0) + group_clp_labels = combined_matrix.coords["clp_label"] + reduced_matrix = reduce_matrix( + combined_matrix, self.model, self.parameters, self._model_dimension, global_index ) - clp, matrix = combine_matrices(constraint_labels_and_matrices) - return LabelAndMatrix(clp, matrix) + return matrices, group_clp_labels, reduced_matrix results = list( map(lambda group: calculate_group(group, self._filled_dataset_descriptors), self._bag) ) - clp_labels = list(map(lambda result: [r[0].clp_label for r in result[0]], results)) - matrices = list(map(lambda result: [r[0].matrix for r in result[0]], results)) + matrices = list(map(lambda result: result[0], results)) - self._clp_labels = {} self._matrices = {} - for i, grouped_problem in enumerate(self._bag): for j, descriptor in enumerate(grouped_problem.descriptor): - if descriptor.label not in self._clp_labels: - self._clp_labels[descriptor.label] = [] + if descriptor.label not in self._matrices: self._matrices[descriptor.label] = [] - self._clp_labels[descriptor.label].append(clp_labels[i][j]) self._matrices[descriptor.label].append(matrices[i][j]) - reduced_results = list(map(reduce_and_combine_matrices, results)) - self._reduced_clp_labels = list(map(lambda result: result.clp_label, reduced_results)) - self._reduced_matrices = list(map(lambda result: result.matrix, reduced_results)) - return self._clp_labels, self._matrices, self._reduced_clp_labels, self._reduced_matrices + self._group_clp_labels = list(map(lambda result: result[1], results)) + self._reduced_matrices = list(map(lambda result: result[2], results)) + return self._matrices, self._reduced_matrices def calculate_index_independent_matrices( self, - ) -> tuple[dict[str, list[str]], dict[str, np.ndarray], dict[str, LabelAndMatrix],]: + ) -> tuple[dict[str, xr.DataArray], dict[str, xr.DataArray],]: """Calculates the index independent model matrices.""" - self._clp_labels = {} self._matrices = {} - self._reduced_clp_labels = {} + self._group_clp_labels = {} self._reduced_matrices = {} - for label, descriptor in self._filled_dataset_descriptors.items(): - model_axis = self._data[label].coords[self._model_dimension].values - global_axis = self._data[label].coords[self._global_dimension].values - result = calculate_matrix( - self._model, - descriptor, + for label, dataset_model in self._filled_dataset_descriptors.items(): + self._matrices[label] = calculate_matrix( + dataset_model, {}, - { - self._model_dimension: model_axis, - self._global_dimension: global_axis, - }, ) - - self._clp_labels[label] = result.clp_label - self._matrices[label] = result.matrix - reduced_result = reduce_matrix(self._model, label, self._parameters, result, None) - self._reduced_clp_labels[label] = reduced_result.clp_label - self._reduced_matrices[label] = reduced_result.matrix + self._group_clp_labels[label] = self._matrices[label].coords["clp_label"] + self._reduced_matrices[label] = reduce_matrix( + self._matrices[label], + self.model, + self.parameters, + self._model_dimension, + None, + ) for group_label, group in self.groups.items(): if group_label not in self._matrices: - reduced_labels_and_matrix = combine_matrices( - [ - LabelAndMatrix( - self._reduced_clp_labels[label], self._reduced_matrices[label] - ) - for label in group - ] + self._reduced_matrices[group_label] = xr.concat( + [self._reduced_matrices[label] for 
label in group], dim=self._model_dimension + ).fillna(0) + group_clp_labels = xr.align( + *(self._matrices[label].coords["clp_label"] for label in group), join="outer" ) - self._reduced_clp_labels[group_label] = reduced_labels_and_matrix.clp_label - self._reduced_matrices[group_label] = reduced_labels_and_matrix.matrix + self._group_clp_labels[group_label] = group_clp_labels[0].coords["clp_label"] - return self._clp_labels, self._matrices, self._reduced_clp_labels, self._reduced_matrices + return self._matrices, self._reduced_matrices def calculate_residual(self): - if self._index_dependent: - self.calculate_index_dependent_residual() - else: - self.calculate_index_independent_residual() - - def calculate_index_dependent_residual( - self, - ) -> tuple[list[np.ndarray], list[np.ndarray], list[np.ndarray], list[np.ndarray],]: - """Calculates the index dependent residuals.""" - - def residual_function( - problem: ProblemGroup, matrix: np.ndarray - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - - matrix = matrix.copy() - for i in range(matrix.shape[1]): - matrix[:, i] *= problem.weight - data = problem.data - if problem.has_scaling: - for i, descriptor in enumerate(problem.descriptor): - label = descriptor.label - if self.filled_dataset_descriptors[label] is not None: - start = sum(problem.data_sizes[0:i]) - end = start + problem.data_sizes[i] - matrix[start:end, :] *= self.filled_dataset_descriptors[label].scale - - clp, residual = self._residual_function(matrix, data) - return clp, residual, residual / problem.weight - - results = list(map(residual_function, self.bag, self.reduced_matrices)) + results = ( + list( + map( + self._index_dependent_residual, + self.bag, + self.reduced_matrices, + self._group_clp_labels, + self._full_axis, + ) + ) + if self._index_dependent + else list(map(self._index_independent_residual, self.bag, self._full_axis)) + ) - self._weighted_residuals = list(map(lambda result: result[1], results)) - self._residuals = list(map(lambda result: result[2], results)) + clps = xr.concat(list(map(lambda result: result[0], results)), dim=self._global_dimension) + clps.coords[self._global_dimension] = self._full_axis + reduced_clps = xr.concat( + list(map(lambda result: result[1], results)), dim=self._global_dimension + ) + reduced_clps.coords[self._global_dimension] = self._full_axis + self._ungroup_clps(clps, reduced_clps) - reduced_clps = list(map(lambda result: result[0], results)) - self._ungroup_clps(reduced_clps) + self._weighted_residuals = list(map(lambda result: result[2], results)) + self._residuals = list(map(lambda result: result[3], results)) + self._additional_penalty = calculate_clp_penalties( + self.model, self.parameters, clps, self._global_dimension + ) return self._reduced_clps, self._clps, self._weighted_residuals, self._residuals - def calculate_index_independent_residual( + def _index_dependent_residual( self, - ) -> tuple[list[np.ndarray], list[np.ndarray], list[np.ndarray], list[np.ndarray],]: - """Calculates the index independent residuals.""" - - def residual_function(problem: ProblemGroup): - matrix = self.reduced_matrices[problem.group].copy() - for i in range(matrix.shape[1]): - matrix[:, i] *= problem.weight - data = problem.data - if problem.has_scaling: - for i, descriptor in enumerate(problem.descriptor): - label = descriptor.label - if self.filled_dataset_descriptors[label] is not None: - start = sum(problem.data_sizes[0:i]) - end = start + problem.data_sizes[i] - matrix[start:end, :] *= self.filled_dataset_descriptors[label].scale - clp, 
residual = self._residual_function(matrix, data) - return clp, residual, residual / problem.weight - - results = list(map(residual_function, self.bag)) - - self._weighted_residuals = list(map(lambda result: result[1], results)) - self._residuals = list(map(lambda result: result[2], results)) - - reduced_clps = list(map(lambda result: result[0], results)) - self._ungroup_clps(reduced_clps) - - return self._reduced_clps, self._clps, self._weighted_residuals, self._residuals + problem: ProblemGroup, + matrix: np.ndarray, + group_clp_labels: str, + index: any, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + + matrix = matrix.copy() + for i in range(matrix.shape[1]): + matrix[:, i] *= problem.weight + data = problem.data + if problem.has_scaling: + for i, descriptor in enumerate(problem.descriptor): + label = descriptor.label + if self.filled_dataset_descriptors[label] is not None: + start = sum(problem.data_sizes[0:i]) + end = start + problem.data_sizes[i] + matrix[start:end, :] *= self.filled_dataset_descriptors[label].scale + + reduced_clps, residual = self._residual_function(matrix.values, data) + reduced_clps = xr.DataArray( + reduced_clps, dims=["clp_label"], coords={"clp_label": matrix.coords["clp_label"]} + ) + clps = retrieve_clps( + self.model, + self.parameters, + group_clp_labels, + reduced_clps, + index, + ) + return clps, reduced_clps, residual, residual / problem.weight + + def _index_independent_residual(self, problem: ProblemGroup, index: any): + matrix = self.reduced_matrices[problem.group].copy() + for i in range(matrix.shape[1]): + matrix[:, i] *= problem.weight + data = problem.data + if problem.has_scaling: + for i, descriptor in enumerate(problem.descriptor): + label = descriptor.label + if self.filled_dataset_descriptors[label] is not None: + start = sum(problem.data_sizes[0:i]) + end = start + problem.data_sizes[i] + matrix[start:end, :] *= self.filled_dataset_descriptors[label].scale + reduced_clps, residual = self._residual_function(matrix.values, data) + reduced_clps = xr.DataArray( + reduced_clps, dims=["clp_label"], coords={"clp_label": matrix.coords["clp_label"]} + ) + clps = retrieve_clps( + self.model, + self.parameters, + self._group_clp_labels[problem.group], + reduced_clps, + index, + ) + return clps, reduced_clps, residual, residual / problem.weight - def _ungroup_clps(self, reduced_clps: np.ndarray): - reduced_clp_labels = self.reduced_clp_labels - self._reduced_clp_labels = {} + def _ungroup_clps(self, clps: xr.DataArray, reduced_clps: xr.DataArray): self._reduced_clps = {} - for label, clp_labels in self.clp_labels.items(): + self._clps = {} + for label in self.matrices: + clp_labels = ( + [m.coords["clp_label"] for m in self.matrices[label]] + if self._index_dependent + else self.matrices[label].coords["clp_label"] + ) # find offset in the full axis offset = find_closest_index( self.data[label].coords[self._global_dimension][0].values, self._full_axis ) - self._reduced_clp_labels[label] = [] self._reduced_clps[label] = [] + self._clps[label] = [] + for i in range(self.data[label].coords[self._global_dimension].size): - group_label = self.bag[i].group - dataset_clp_labels = clp_labels[i] if self._index_dependent else clp_labels - index_clp_labels = ( - reduced_clp_labels[i + offset] - if self._index_dependent - else reduced_clp_labels[group_label] + + index_clp_labels = clp_labels[i] if self._index_dependent else clp_labels + index_clps = clps[i + offset] + index_clps = index_clps.sel({"clp_label": index_clp_labels}) + 
self._clps[label].append(index_clps) + + index_reduced_clps = reduced_clps[i + offset] + index_reduced_clp_labels, _ = xr.align( + index_clp_labels, index_reduced_clps.coords["clp_label"] ) - self._reduced_clp_labels[label].append( - [ - clp_label - for clp_label in dataset_clp_labels - if clp_label in index_clp_labels - ] + index_reduced_clps = index_reduced_clps.sel( + {"clp_label": index_reduced_clp_labels} ) + self._reduced_clps[label].append(index_reduced_clps) - mask = [ - clp_label in self._reduced_clp_labels[label][i] - for clp_label in index_clp_labels - ] - self._reduced_clps[label].append(reduced_clps[i + offset][mask]) - self._clps = ( - self.model.retrieve_clp_function( - self.parameters, - self.clp_labels, - self.reduced_clp_labels, - self.reduced_clps, - self.data, + self._reduced_clps[label] = xr.concat( + self.reduced_clps[label], dim=self._global_dimension ) - if callable(self.model.retrieve_clp_function) - else self._reduced_clps - ) + self._reduced_clps[label].coords[self._global_dimension] = self.data[label].coords[ + self._global_dimension + ] + + self._clps[label] = xr.concat(self._clps[label], dim=self._global_dimension) + self._clps[label].coords[self._global_dimension] = self.data[label].coords[ + self._global_dimension + ] def create_index_dependent_result_dataset(self, label: str, dataset: xr.Dataset) -> xr.Dataset: """Creates a result datasets for index dependent matrices.""" @@ -431,9 +416,6 @@ def create_index_dependent_result_dataset(self, label: str, dataset: xr.Dataset) dataset, grouped_problem, index, group_index, global_index ) - # we assume that the labels are the same, this might not be true in - # future models - dataset.coords["clp_label"] = self.clp_labels[label][0] dataset["matrix"] = ( ( (self._global_dimension), @@ -442,13 +424,7 @@ def create_index_dependent_result_dataset(self, label: str, dataset: xr.Dataset) ), self.matrices[label], ) - dataset["clp"] = ( - ( - (self._global_dimension), - ("clp_label"), - ), - self.clps[label], - ) + dataset["clp"] = self.clps[label] return dataset @@ -457,7 +433,8 @@ def create_index_independent_result_dataset( ) -> xr.Dataset: """Creates a result datasets for index independent matrices.""" - self._add_index_independent_matrix_to_dataset(label, dataset) + dataset["matrix"] = self.matrices[label] + dataset["clp"] = self.clps[label] for index, grouped_problem in enumerate(self.bag): @@ -473,26 +450,8 @@ def create_index_independent_result_dataset( dataset, grouped_problem, index, group_index, global_index ) - dataset["clp"] = ( - ( - (self._global_dimension), - ("clp_label"), - ), - self.clps[label], - ) - return dataset - def _add_index_independent_matrix_to_dataset(self, label: str, dataset: xr.Dataset): - dataset.coords["clp_label"] = self.clp_labels[label] - dataset["matrix"] = ( - ( - (self._model_dimension), - ("clp_label"), - ), - np.asarray(self.matrices[label]), - ) - def _add_grouped_residual_to_dataset( self, dataset: xr.Dataset, @@ -525,3 +484,16 @@ def _add_grouped_residual_to_dataset( dataset.residual.loc[{self._global_dimension: global_index}] = self.residuals[index][ start:end ] + + @property + def full_penalty(self) -> np.ndarray: + if self._full_penalty is None: + residuals = self.weighted_residuals + additional_penalty = self.additional_penalty + + self._full_penalty = ( + np.concatenate((np.concatenate(residuals), additional_penalty)) + if additional_penalty is not None + else np.concatenate(residuals) + ) + return self._full_penalty diff --git a/glotaran/analysis/problem_ungrouped.py 
b/glotaran/analysis/problem_ungrouped.py index 62d4422c5..38a8d522b 100644 --- a/glotaran/analysis/problem_ungrouped.py +++ b/glotaran/analysis/problem_ungrouped.py @@ -6,8 +6,10 @@ from glotaran.analysis.problem import ParameterError from glotaran.analysis.problem import Problem from glotaran.analysis.problem import UngroupedProblemDescriptor +from glotaran.analysis.util import calculate_clp_penalties from glotaran.analysis.util import calculate_matrix from glotaran.analysis.util import reduce_matrix +from glotaran.analysis.util import retrieve_clps from glotaran.model import DatasetDescriptor @@ -35,25 +37,17 @@ def init_bag(self): def calculate_matrices( self, ) -> tuple[ - dict[str, list[list[str]] | list[str]], - dict[str, list[np.ndarray] | np.ndarray], - dict[str, list[str]], - dict[str, list[np.ndarray] | np.ndarray], + dict[str, list[xr.DataArray] | xr.DataArray], + dict[str, list[xr.DataArray] | xr.DataArray], ]: """Calculates the model matrices.""" if self._parameters is None: raise ParameterError - self._clp_labels = {} self._matrices = {} - self._reduced_clp_labels = {} self._reduced_matrices = {} for label, problem in self.bag.items(): - self._clp_labels[label] = [] - self._matrices[label] = [] - self._reduced_clp_labels[label] = [] - self._reduced_matrices[label] = [] dataset_model = self._filled_dataset_descriptors[label] if dataset_model.index_dependent(): @@ -61,51 +55,37 @@ def calculate_matrices( else: self._calculate_index_independent_matrix(label, problem, dataset_model) - return self._clp_labels, self._matrices, self._reduced_clp_labels, self._reduced_matrices + return self._matrices, self._reduced_matrices def _calculate_index_dependent_matrix( self, label: str, problem: UngroupedProblemDescriptor, dataset_model: DatasetDescriptor ): + self._matrices[label] = [] + self._reduced_matrices[label] = [] for i, index in enumerate(problem.global_axis): - result = calculate_matrix( - self._model, + matrix = calculate_matrix( dataset_model, {dataset_model.get_global_dimension(): i}, - { - dataset_model.get_model_dimension(): problem.model_axis, - dataset_model.get_global_dimension(): problem.global_axis, - }, ) - - self._clp_labels[label].append(result.clp_label) - self._matrices[label].append(result.matrix) - reduced_labels_and_matrix = reduce_matrix( - self._model, label, self._parameters, result, index + self._matrices[label].append(matrix) + reduced_matrix = reduce_matrix( + matrix, self.model, self.parameters, dataset_model.get_model_dimension(), index ) - self._reduced_clp_labels[label].append(reduced_labels_and_matrix.clp_label) - self._reduced_matrices[label].append(reduced_labels_and_matrix.matrix) + self._reduced_matrices[label].append(reduced_matrix) def _calculate_index_independent_matrix( self, label: str, problem: UngroupedProblemDescriptor, dataset_model: DatasetDescriptor ): - model_dimension = dataset_model.get_model_dimension() - global_dimension = dataset_model.get_global_dimension() - result = calculate_matrix( - self._model, + matrix = calculate_matrix( dataset_model, {}, - { - model_dimension: problem.model_axis, - global_dimension: problem.global_axis, - }, ) - - self._clp_labels[label] = result.clp_label - self._matrices[label] = result.matrix - reduced_result = reduce_matrix(self._model, label, self._parameters, result, None) - self._reduced_clp_labels[label] = reduced_result.clp_label - self._reduced_matrices[label] = reduced_result.matrix + self._matrices[label] = matrix + reduced_matrix = reduce_matrix( + matrix, self.model, self.parameters, 
dataset_model.get_model_dimension(), None + ) + self._reduced_matrices[label] = reduced_matrix def calculate_residual( self, @@ -118,51 +98,64 @@ def calculate_residual( """Calculates the residuals.""" self._reduced_clps = {} + self._clps = {} self._weighted_residuals = {} self._residuals = {} + self._additional_penalty = [] for label, problem in self.bag.items(): self._calculate_residual_for_problem(label, problem) - self._clps = ( - self.model.retrieve_clp_function( - self.parameters, - self.clp_labels, - self.reduced_clp_labels, - self.reduced_clps, - self.data, - ) - if callable(self.model.retrieve_clp_function) - else self.reduced_clps + self._additional_penalty = ( + np.concatenate(self._additional_penalty) if len(self._additional_penalty) != 0 else [] ) - return self._reduced_clps, self._clps, self._weighted_residuals, self._residuals def _calculate_residual_for_problem(self, label: str, problem: UngroupedProblemDescriptor): self._reduced_clps[label] = [] + self._clps[label] = [] self._weighted_residuals[label] = [] self._residuals[label] = [] data = problem.data dataset_model = self._filled_dataset_descriptors[label] + model_dimension = dataset_model.get_model_dimension() global_dimension = dataset_model.get_global_dimension() - for i in range(len(problem.global_axis)): - matrix = ( + for i, index in enumerate(problem.global_axis): + clp_labels = ( + self.matrices[label][i].coords["clp_label"] + if dataset_model.index_dependent() + else self.matrices[label].coords["clp_label"] + ) + reduced_matrix = ( self.reduced_matrices[label][i] if dataset_model.index_dependent() - else self.reduced_matrices[label].copy() - ) # TODO: .copy() or not + else self.reduced_matrices[label] + ) if problem.dataset.scale is not None: - matrix *= self.filled_dataset_descriptors[label].scale + reduced_matrix *= self.filled_dataset_descriptors[label].scale if problem.weight is not None: - for j in range(matrix.shape[1]): - matrix[:, j] *= problem.weight.isel({global_dimension: i}).values + for j in range(reduced_matrix.shape[1]): + reduced_matrix[:, j] *= problem.weight.isel({global_dimension: i}).values - clp, residual = self._residual_function( - matrix, data.isel({global_dimension: i}).values + reduced_clps, residual = self._residual_function( + reduced_matrix.values, data.isel({global_dimension: i}).values + ) + reduced_clps = xr.DataArray( + reduced_clps, + dims=["clp_label"], + coords={"clp_label": reduced_matrix.coords["clp_label"]}, + ) + self._reduced_clps[label].append(reduced_clps) + self._clps[label].append( + retrieve_clps(self.model, self.parameters, clp_labels, reduced_clps, index) + ) + residual = xr.DataArray( + residual, + dims=[model_dimension], + coords={model_dimension: reduced_matrix.coords[model_dimension]}, ) - self._reduced_clps[label].append(clp) self._weighted_residuals[label].append(residual) if problem.weight is not None: self._residuals[label].append( @@ -171,6 +164,16 @@ def _calculate_residual_for_problem(self, label: str, problem: UngroupedProblemD else: self._residuals[label].append(residual) + self._reduced_clps[label] = xr.concat(self._reduced_clps[label], dim=global_dimension) + self._reduced_clps[label].coords[global_dimension] = data.coords[global_dimension] + self._clps[label] = xr.concat(self._clps[label], dim=global_dimension) + self._clps[label].coords[global_dimension] = data.coords[global_dimension] + additional_penalty = calculate_clp_penalties( + self.model, self.parameters, self._clps[label], global_dimension + ) + if additional_penalty.size != 0: + 
self._additional_penalty.append(additional_penalty) + def create_index_dependent_result_dataset(self, label: str, dataset: xr.Dataset) -> xr.Dataset: """Creates a result datasets for index dependent matrices.""" @@ -192,53 +195,62 @@ def create_index_independent_result_dataset( return dataset def _add_index_dependent_matrix_to_dataset(self, label: str, dataset: xr.Dataset): - # we assume that the labels are the same, this might not be true in - # future models - dataset.coords["clp_label"] = self.clp_labels[label][0] model_dimension = self.filled_dataset_descriptors[label].get_model_dimension() global_dimension = self.filled_dataset_descriptors[label].get_global_dimension() + matrix = xr.concat(self.matrices[label], dim=global_dimension) + matrix.coords[global_dimension] = dataset.coords[global_dimension] + dataset.coords["clp_label"] = matrix.coords["clp_label"] dataset["matrix"] = ( ( (global_dimension), (model_dimension), ("clp_label"), ), - np.asarray(self.matrices[label]), + matrix.data, ) def _add_index_independent_matrix_to_dataset(self, label: str, dataset: xr.Dataset): - dataset.coords["clp_label"] = self.clp_labels[label] + dataset.coords["clp_label"] = self.matrices[label].coords["clp_label"] model_dimension = self.filled_dataset_descriptors[label].get_model_dimension() dataset["matrix"] = ( ( (model_dimension), ("clp_label"), ), - np.asarray(self.matrices[label]), + self.matrices[label].data, ) def _add_residual_and_full_clp_to_dataset(self, label: str, dataset: xr.Dataset): model_dimension = self.filled_dataset_descriptors[label].get_model_dimension() global_dimension = self.filled_dataset_descriptors[label].get_global_dimension() - dataset["clp"] = ( - ( - (global_dimension), - ("clp_label"), - ), - np.asarray(self.clps[label]), - ) + dataset["clp"] = self.clps[label] dataset["weighted_residual"] = ( ( (model_dimension), (global_dimension), ), - np.transpose(np.asarray(self.weighted_residuals[label])), + xr.concat(self.weighted_residuals[label], dim=global_dimension).T.data, ) dataset["residual"] = ( ( (model_dimension), (global_dimension), ), - np.transpose(np.asarray(self.residuals[label])), + xr.concat(self.residuals[label], dim=global_dimension).T.data, ) + + @property + def full_penalty(self) -> np.ndarray: + if self._full_penalty is None: + residuals = self.weighted_residuals + additional_penalty = self.additional_penalty + print(residuals) + residuals = [np.concatenate(residuals[label]) for label in residuals.keys()] + + self._full_penalty = ( + np.concatenate((np.concatenate(residuals), additional_penalty)) + if additional_penalty is not None + else np.concatenate(residuals) + ) + return self._full_penalty diff --git a/glotaran/analysis/simulation.py b/glotaran/analysis/simulation.py index 22a05f336..c8efb3492 100644 --- a/glotaran/analysis/simulation.py +++ b/glotaran/analysis/simulation.py @@ -63,47 +63,25 @@ def simulate( result = xr.DataArray( data=0.0, coords=[ - (model_dimension, model_axis), - (global_dimension, global_axis), + (model_dimension, model_axis.data), + (global_dimension, global_axis.data), ], ) result = result.to_dataset(name="data") + filled_dataset.set_data(result) matrix = ( [ calculate_matrix( - model, filled_dataset, {global_dimension: index}, - {model_dimension: model_axis, global_dimension: global_axis}, ) for index, _ in enumerate(global_axis) ] if filled_dataset.index_dependent() else calculate_matrix( - model, filled_dataset, {}, - {model_dimension: model_axis, global_dimension: global_axis}, - ) - ) - if 
callable(model.constrain_matrix_function): - matrix = ( - [ - model.constrain_matrix_function(dataset, parameters, clp, mat, global_axis[i]) - for i, (clp, mat) in enumerate(matrix) - ] - if filled_dataset.index_dependent() - else model.constrain_matrix_function(dataset, parameters, matrix[0], matrix[1], None) - ) - matrix = ( - [ - xr.DataArray(mat, coords=[(model_dimension, model_axis), ("clp_label", clp_label)]) - for clp_label, mat in matrix - ] - if filled_dataset.index_dependent() - else xr.DataArray( - matrix[1], coords=[(model_dimension, model_axis), ("clp_label", matrix[0])] ) ) @@ -135,6 +113,7 @@ def simulate( ) for i in range(global_axis.size): index_matrix = matrix[i] if filled_dataset.index_dependent() else matrix + print(index_matrix.coords) result.data[:, i] = np.dot( index_matrix, clp[i].sel(clp_label=index_matrix.coords["clp_label"]) ) diff --git a/glotaran/analysis/test/models.py b/glotaran/analysis/test/models.py index 9d9cd19b8..5d9d409b1 100644 --- a/glotaran/analysis/test/models.py +++ b/glotaran/analysis/test/models.py @@ -29,7 +29,8 @@ def calculate_e(dataset, axis): @megacomplex("c", properties={"is_index_dependent": bool}) class SimpleTestMegacomplex(Megacomplex): - def calculate_matrix(self, model, dataset_descriptor, indices, axis, **kwargs): + def calculate_matrix(self, dataset_model, indices, **kwargs): + axis = dataset_model.get_data().coords assert "c" in axis assert "e" in axis @@ -42,7 +43,7 @@ def calculate_matrix(self, model, dataset_descriptor, indices, axis, **kwargs): r_compartments.append(compartments[i]) for j in range(axis.shape[0]): array[j, i] = (i + j) * axis[j] - return (r_compartments, array) + return xr.DataArray(array, coords=(("c", axis.data), ("clp_label", r_compartments))) def index_dependent(self, dataset_model): return self.is_index_dependent @@ -62,18 +63,19 @@ class SimpleTestModel(Model): @megacomplex("c", properties={"is_index_dependent": bool}) class SimpleKineticMegacomplex(Megacomplex): - def calculate_matrix(self, model, dataset_descriptor, indices, axis, **kwargs): + def calculate_matrix(self, dataset_model, indices, **kwargs): + axis = dataset_model.get_data().coords assert "c" in axis assert "e" in axis axis = axis["c"] - kinpar = -1 * np.asarray(dataset_descriptor.kinetic) - if dataset_descriptor.label == "dataset3": + kinpar = -1 * np.asarray(dataset_model.kinetic) + if dataset_model.label == "dataset3": # this case is for the ThreeDatasetDecay test compartments = [f"s{i+2}" for i in range(len(kinpar))] else: compartments = [f"s{i+1}" for i in range(len(kinpar))] array = np.exp(np.outer(axis, kinpar)) - return (compartments, array) + return xr.DataArray(array, coords=(("c", axis.data), ("clp_label", compartments))) def index_dependent(self, dataset_model): return self.is_index_dependent @@ -235,11 +237,11 @@ class GaussianShapeDecayDatasetDescriptor(DatasetDescriptor): global_matrix=calculate_spectral_simple, global_dimension="e", megacomplex_types=SimpleKineticMegacomplex, - has_additional_penalty_function=lambda model: True, - additional_penalty_function=additional_penalty_typecheck, - has_matrix_constraints_function=lambda model: True, - constrain_matrix_function=constrain_matrix_function_typecheck, - retrieve_clp_function=retrieve_clp_typecheck, + # has_additional_penalty_function=lambda model: True, + # additional_penalty_function=additional_penalty_typecheck, + # has_matrix_constraints_function=lambda model: True, + # constrain_matrix_function=constrain_matrix_function_typecheck, + # 
retrieve_clp_function=retrieve_clp_typecheck, grouped=lambda model: model.is_grouped, ) class DecayModel(Model): @@ -258,8 +260,8 @@ class DecayModel(Model): global_dimension="e", megacomplex_types=SimpleKineticMegacomplex, grouped=lambda model: model.is_grouped, - has_additional_penalty_function=lambda model: True, - additional_penalty_function=additional_penalty_typecheck, + # has_additional_penalty_function=lambda model: True, + # additional_penalty_function=additional_penalty_typecheck, ) class GaussianDecayModel(Model): additional_penalty_function_called = False diff --git a/glotaran/analysis/test/test_constraints.py b/glotaran/analysis/test/test_constraints.py new file mode 100644 index 000000000..c83eaf88a --- /dev/null +++ b/glotaran/analysis/test/test_constraints.py @@ -0,0 +1,44 @@ +from copy import deepcopy + +import pytest + +from glotaran.analysis.problem_grouped import GroupedProblem +from glotaran.analysis.problem_ungrouped import UngroupedProblem +from glotaran.analysis.simulation import simulate +from glotaran.analysis.test.models import TwoCompartmentDecay as suite +from glotaran.model import ZeroConstraint +from glotaran.project import Scheme + + +@pytest.mark.parametrize("index_dependent", [True, False]) +@pytest.mark.parametrize("grouped", [True, False]) +def test_constraint(index_dependent, grouped): + model = deepcopy(suite.model) + model.megacomplex["m1"].is_index_dependent = index_dependent + model.constraints.append(ZeroConstraint.from_dict({"target": "s2"})) + + print("grouped", grouped, "index_dependent", index_dependent) + dataset = simulate( + suite.sim_model, + "dataset1", + suite.wanted_parameters, + {"e": suite.e_axis, "c": suite.c_axis}, + ) + scheme = Scheme(model=model, parameters=suite.initial_parameters, data={"dataset1": dataset}) + problem = GroupedProblem(scheme) if grouped else UngroupedProblem(scheme) + + reduced_clps = problem.reduced_clps["dataset1"] + if index_dependent: + reduced_matrix = ( + problem.reduced_matrices[0] if grouped else problem.reduced_matrices["dataset1"][0] + ) + else: + reduced_matrix = problem.reduced_matrices["dataset1"] + matrix = problem.matrices["dataset1"][0] if index_dependent else problem.matrices["dataset1"] + clps = problem.clps["dataset1"] + + assert "s2" not in reduced_clps.coords["clp_label"] + assert "s2" not in reduced_matrix.coords["clp_label"] + assert "s2" in clps.coords["clp_label"] + assert clps.sel(clp_label="s2") == 0 + assert "s2" in matrix.coords["clp_label"] diff --git a/glotaran/analysis/test/test_optimization.py b/glotaran/analysis/test/test_optimization.py index 28209a6be..6f3acc4db 100644 --- a/glotaran/analysis/test/test_optimization.py +++ b/glotaran/analysis/test/test_optimization.py @@ -4,7 +4,6 @@ from glotaran.analysis.optimize import optimize from glotaran.analysis.simulation import simulate -from glotaran.analysis.test.models import DecayModel from glotaran.analysis.test.models import MultichannelMulticomponentDecay from glotaran.analysis.test.models import OneCompartmentDecay from glotaran.analysis.test.models import ThreeDatasetDecay @@ -133,14 +132,14 @@ def test_optimization(suite, index_dependent, grouped, weight, method): assert "weighted_residual_right_singular_vectors" in resultdata assert "weighted_residual_singular_values" in resultdata - assert callable(model.additional_penalty_function) - assert model.additional_penalty_function_called - - if isinstance(model, DecayModel): - assert callable(model.constrain_matrix_function) - assert model.constrain_matrix_function_called - 
assert callable(model.retrieve_clp_function) - assert model.retrieve_clp_function_called - else: - assert not model.constrain_matrix_function_called - assert not model.retrieve_clp_function_called + # assert callable(model.additional_penalty_function) + # assert model.additional_penalty_function_called + # + # if isinstance(model, DecayModel): + # assert callable(model.constrain_matrix_function) + # assert model.constrain_matrix_function_called + # assert callable(model.retrieve_clp_function) + # assert model.retrieve_clp_function_called + # else: + # assert not model.constrain_matrix_function_called + # assert not model.retrieve_clp_function_called diff --git a/glotaran/analysis/test/test_penalties.py b/glotaran/analysis/test/test_penalties.py new file mode 100644 index 000000000..72019a0d7 --- /dev/null +++ b/glotaran/analysis/test/test_penalties.py @@ -0,0 +1,53 @@ +from copy import deepcopy + +import numpy as np +import pytest + +from glotaran.analysis.problem_grouped import GroupedProblem +from glotaran.analysis.problem_ungrouped import UngroupedProblem +from glotaran.analysis.simulation import simulate +from glotaran.analysis.test.models import TwoCompartmentDecay as suite +from glotaran.model import EqualAreaPenalty +from glotaran.parameter import ParameterGroup +from glotaran.project import Scheme + + +@pytest.mark.parametrize("index_dependent", [True, False]) +@pytest.mark.parametrize("grouped", [True, False]) +def test_constraint(index_dependent, grouped): + model = deepcopy(suite.model) + model.megacomplex["m1"].is_index_dependent = index_dependent + model.clp_area_penalties.append( + EqualAreaPenalty.from_dict( + { + "source": "s1", + "source_intervals": [(1, 20)], + "target": "s2", + "target_intervals": [(20, 45)], + "parameter": "3", + "weight": 10, + } + ) + ) + parameters = ParameterGroup.from_list([11e-4, 22e-5, 2]) + + e_axis = np.arange(50) + + print("grouped", grouped, "index_dependent", index_dependent) + dataset = simulate( + suite.sim_model, + "dataset1", + parameters, + {"e": e_axis, "c": suite.c_axis}, + ) + scheme = Scheme(model=model, parameters=parameters, data={"dataset1": dataset}) + problem = GroupedProblem(scheme) if grouped else UngroupedProblem(scheme) + + assert isinstance(problem.additional_penalty, np.ndarray) + assert problem.additional_penalty.size == 1 + assert problem.additional_penalty[0] != 0 + assert isinstance(problem.full_penalty, np.ndarray) + assert ( + problem.full_penalty.size + == (suite.c_axis.size * e_axis.size) + problem.additional_penalty.size + ) diff --git a/glotaran/analysis/test/test_problem.py b/glotaran/analysis/test/test_problem.py index 2de74850e..30d3c60c4 100644 --- a/glotaran/analysis/test/test_problem.py +++ b/glotaran/analysis/test/test_problem.py @@ -51,33 +51,25 @@ def test_problem_matrices(problem: Problem): if problem.grouped: if problem.model.is_index_dependent: - assert all(isinstance(m, list) for m in problem.reduced_clp_labels) - assert all(isinstance(m, np.ndarray) for m in problem.reduced_matrices) - assert len(problem.reduced_clp_labels) == suite.e_axis.size + assert all(isinstance(m, xr.DataArray) for m in problem.reduced_matrices) assert len(problem.reduced_matrices) == suite.e_axis.size else: - assert "dataset1" in problem.reduced_clp_labels assert "dataset1" in problem.reduced_matrices - assert isinstance(problem.reduced_clp_labels["dataset1"], list) - assert isinstance(problem.reduced_matrices["dataset1"], np.ndarray) + assert isinstance(problem.reduced_matrices["dataset1"], xr.DataArray) else: if 
problem.model.is_index_dependent: - assert isinstance(problem.reduced_clp_labels, dict) assert isinstance(problem.reduced_matrices, dict) assert isinstance(problem.reduced_matrices["dataset1"], list) - assert all(isinstance(c, list) for c in problem.reduced_clp_labels["dataset1"]) - assert all(isinstance(m, np.ndarray) for m in problem.reduced_matrices["dataset1"]) + assert all(isinstance(m, xr.DataArray) for m in problem.reduced_matrices["dataset1"]) else: - assert isinstance(problem.reduced_matrices["dataset1"], np.ndarray) + assert isinstance(problem.reduced_matrices["dataset1"], xr.DataArray) - assert isinstance(problem.clp_labels, dict) assert isinstance(problem.matrices, dict) - assert isinstance(problem.reduced_clp_labels["dataset1"], list) - assert "dataset1" in problem.reduced_clp_labels assert "dataset1" in problem.reduced_matrices def test_problem_residuals(problem: Problem): + print("Grouped", problem.model.is_grouped, "Indexdep", problem.model.is_index_dependent) problem.calculate_residual() if problem.grouped: assert isinstance(problem.residuals, list) @@ -86,28 +78,21 @@ def test_problem_residuals(problem: Problem): else: assert isinstance(problem.residuals, dict) assert "dataset1" in problem.residuals - assert all(isinstance(r, np.ndarray) for r in problem.residuals["dataset1"]) + assert all(isinstance(r, xr.DataArray) for r in problem.residuals["dataset1"]) assert len(problem.residuals["dataset1"]) == suite.e_axis.size assert isinstance(problem.reduced_clps, dict) assert "dataset1" in problem.reduced_clps - assert all(isinstance(c, np.ndarray) for c in problem.reduced_clps["dataset1"]) + assert all(isinstance(c, xr.DataArray) for c in problem.reduced_clps["dataset1"]) assert len(problem.reduced_clps["dataset1"]) == suite.e_axis.size assert isinstance(problem.clps, dict) assert "dataset1" in problem.clps - assert all(isinstance(c, np.ndarray) for c in problem.clps["dataset1"]) + assert all(isinstance(c, xr.DataArray) for c in problem.clps["dataset1"]) assert len(problem.clps["dataset1"]) == suite.e_axis.size - assert isinstance(problem.additional_penalty, np.ndarray) - assert problem.additional_penalty.size == 1 - assert problem.additional_penalty[0] == 0.1 - assert isinstance(problem.full_penalty, np.ndarray) - assert ( - problem.full_penalty.size - == (suite.c_axis.size * suite.e_axis.size) + problem.additional_penalty.size - ) def test_problem_result_data(problem: Problem): + print("Grouped", problem.model.is_grouped, "Indexdep", problem.model.is_index_dependent) data = problem.create_result_data() label = "dataset1" diff --git a/glotaran/analysis/test/test_relations.py b/glotaran/analysis/test/test_relations.py new file mode 100644 index 000000000..2cdbde279 --- /dev/null +++ b/glotaran/analysis/test/test_relations.py @@ -0,0 +1,46 @@ +from copy import deepcopy + +import pytest + +from glotaran.analysis.problem_grouped import GroupedProblem +from glotaran.analysis.problem_ungrouped import UngroupedProblem +from glotaran.analysis.simulation import simulate +from glotaran.analysis.test.models import TwoCompartmentDecay as suite +from glotaran.model import Relation +from glotaran.parameter import ParameterGroup +from glotaran.project import Scheme + + +@pytest.mark.parametrize("index_dependent", [True, False]) +@pytest.mark.parametrize("grouped", [True, False]) +def test_constraint(index_dependent, grouped): + model = deepcopy(suite.model) + model.megacomplex["m1"].is_index_dependent = index_dependent + model.relations.append(Relation.from_dict({"source": "s1", "target": 
"s2", "parameter": "3"})) + parameters = ParameterGroup.from_list([11e-4, 22e-5, 2]) + + print("grouped", grouped, "index_dependent", index_dependent) + dataset = simulate( + suite.sim_model, + "dataset1", + parameters, + {"e": suite.e_axis, "c": suite.c_axis}, + ) + scheme = Scheme(model=model, parameters=parameters, data={"dataset1": dataset}) + problem = GroupedProblem(scheme) if grouped else UngroupedProblem(scheme) + + reduced_clps = problem.reduced_clps["dataset1"] + if index_dependent: + reduced_matrix = ( + problem.reduced_matrices[0] if grouped else problem.reduced_matrices["dataset1"][0] + ) + else: + reduced_matrix = problem.reduced_matrices["dataset1"] + matrix = problem.matrices["dataset1"][0] if index_dependent else problem.matrices["dataset1"] + clps = problem.clps["dataset1"] + + assert "s2" not in reduced_clps.coords["clp_label"] + assert "s2" not in reduced_matrix.coords["clp_label"] + assert "s2" in clps.coords["clp_label"] + assert clps.sel(clp_label="s2") == clps.sel(clp_label="s1") * 2 + assert "s2" in matrix.coords["clp_label"] diff --git a/glotaran/analysis/util.py b/glotaran/analysis/util.py index b4818eec8..f1ee25df9 100644 --- a/glotaran/analysis/util.py +++ b/glotaran/analysis/util.py @@ -1,20 +1,16 @@ from __future__ import annotations import itertools -from typing import NamedTuple +from typing import Any import numpy as np +import xarray as xr from glotaran.model import DatasetDescriptor from glotaran.model import Model from glotaran.parameter import ParameterGroup -class LabelAndMatrix(NamedTuple): - clp_label: list[str] - matrix: np.ndarray - - def find_overlap(a, b, rtol=1e-05, atol=1e-08): ovr_a = [] ovr_b = [] @@ -44,79 +40,165 @@ def get_min_max_from_interval(interval, axis): def calculate_matrix( - model: Model, dataset_descriptor: DatasetDescriptor, - indices: dict[str, int], - axis: dict[str, np.ndarray], -) -> LabelAndMatrix: - clp_labels = None + indices: dict[str, int] | None, +) -> xr.DataArray: matrix = None for scale, megacomplex in dataset_descriptor.iterate_megacomplexes(): - this_clp_labels, this_matrix = megacomplex.calculate_matrix( - model, dataset_descriptor, indices, axis - ) + this_matrix = megacomplex.calculate_matrix(dataset_descriptor, indices) if scale is not None: this_matrix *= scale if matrix is None: - clp_labels = this_clp_labels matrix = this_matrix else: - tmp_clp_labels = clp_labels + [c for c in this_clp_labels if c not in clp_labels] - tmp_matrix = np.zeros((matrix.shape[0], len(tmp_clp_labels)), dtype=np.float64) - for idx, label in enumerate(tmp_clp_labels): - if label in clp_labels: - tmp_matrix[:, idx] += matrix[:, clp_labels.index(label)] - if label in this_clp_labels: - tmp_matrix[:, idx] += this_matrix[:, this_clp_labels.index(label)] - clp_labels = tmp_clp_labels - matrix = tmp_matrix + matrix, this_matrix = xr.align(matrix, this_matrix, join="outer", copy=False) + matrix = matrix.fillna(0) + matrix += this_matrix.fillna(0) - return LabelAndMatrix(clp_labels, matrix) + return matrix def reduce_matrix( + matrix: xr.DataArray, model: Model, - label: str, parameters: ParameterGroup, - result: LabelAndMatrix, - index: float | None, -) -> LabelAndMatrix: - clp_labels = result.clp_label.copy() - if callable(model.has_matrix_constraints_function) and model.has_matrix_constraints_function(): - clp_label, matrix = model.constrain_matrix_function( - label, parameters, clp_labels, result.matrix, index - ) - return LabelAndMatrix(clp_label, matrix) - return LabelAndMatrix(clp_labels, result.matrix) - - -def 
combine_matrices(labels_and_matrices: list[LabelAndMatrix]) -> LabelAndMatrix: - masks = [] - full_clp_labels = None - sizes = [] - for label_and_matrix in labels_and_matrices: - (clp_label, matrix) = label_and_matrix - sizes.append(matrix.shape[0]) - if full_clp_labels is None: - full_clp_labels = clp_label - masks.append([i for i, _ in enumerate(clp_label)]) - else: - mask = [] - for c in clp_label: - if c not in full_clp_labels: - full_clp_labels.append(c) - mask.append(full_clp_labels.index(c)) - masks.append(mask) - dim1 = np.sum(sizes) - dim2 = len(full_clp_labels) - full_matrix = np.zeros((dim1, dim2), dtype=np.float64) - start = 0 - for i, m in enumerate(labels_and_matrices): - end = start + sizes[i] - full_matrix[start:end, masks[i]] = m[1] - start = end - - return LabelAndMatrix(full_clp_labels, full_matrix) + model_dimension: str, + index: Any | None, +) -> xr.DataArray: + matrix = apply_relations(matrix, model, parameters, model_dimension, index) + matrix = apply_constraints(matrix, model, index) + return matrix + + +def apply_constraints( + matrix: xr.DataArray, + model: Model, + index: Any | None, +) -> xr.DataArray: + + if len(model.constraints) == 0: + return matrix + + clp_labels = matrix.coords["clp_label"].values + removed_clp = [ + c.target for c in model.constraints if c.target in clp_labels and c.applies(index) + ] + reduced_clp_label = [c for c in clp_labels if c not in removed_clp] + + return matrix.sel({"clp_label": reduced_clp_label}) + + +def apply_relations( + matrix: xr.DataArray, + model: Model, + parameters: ParameterGroup, + model_dimension: str, + index: Any | None, +) -> xr.DataArray: + + if len(model.relations) == 0: + return matrix + + clp_labels = list(matrix.coords["clp_label"].values) + relation_matrix = np.diagflat([1.0 for _ in clp_labels]) + + idx_to_delete = [] + for relation in model.relations: + if relation.target in clp_labels and relation.applies(index): + + if relation.source not in clp_labels: + continue + + relation = relation.fill(model, parameters) + source_idx = clp_labels.index(relation.source) + target_idx = clp_labels.index(relation.target) + relation_matrix[target_idx, source_idx] = relation.parameter + idx_to_delete.append(target_idx) + + reduced_clp_label = [label for i, label in enumerate(clp_labels) if i not in idx_to_delete] + relation_matrix = np.delete(relation_matrix, idx_to_delete, axis=1) + return xr.DataArray( + matrix.values @ relation_matrix, + dims=matrix.dims, + coords={ + "clp_label": reduced_clp_label, + model_dimension: matrix.coords[model_dimension], + }, + ) + + +def retrieve_clps( + model: Model, + parameters: ParameterGroup, + clp_labels: xr.DataArray, + reduced_clps: xr.DataArray, + index: Any | None, +) -> xr.DataArray: + if len(model.relations) == 0 and len(model.constraints) == 0: + return reduced_clps + + clps = xr.DataArray(np.zeros((clp_labels.size), dtype=np.float64), coords=[clp_labels]) + clps.loc[{"clp_label": reduced_clps.coords["clp_label"]}] = reduced_clps.values + + print("ret", clps) + for relation in model.relations: + relation = relation.fill(model, parameters) + print("YYY", relation.target, relation.source, relation.parameter) + if relation.target in clp_labels and relation.applies(index): + if relation.source not in clp_labels: + continue + clps.loc[{"clp_label": relation.target}] = relation.parameter * clps.sel( + clp_label=relation.source + ) + + return clps + + +def calculate_clp_penalties( + model: Model, + parameters: ParameterGroup, + clps: xr.DataArray, + global_dimension: str, +) -> 
np.ndarray: + + penalties = [] + for penalty in model.clp_area_penalties: + if ( + penalty.source in clps.coords["clp_label"] + and penalty.target in clps.coords["clp_label"] + ): + penalty = penalty.fill(model, parameters) + + source_area = xr.concat( + [ + clps.sel( + { + "clp_label": penalty.source, + global_dimension: slice(interval[0], interval[1]), + } + ) + for interval in penalty.source_intervals + ], + dim=global_dimension, + ) + + target_area = xr.concat( + [ + clps.sel( + { + "clp_label": penalty.target, + global_dimension: slice(interval[0], interval[1]), + } + ) + for interval in penalty.target_intervals + ], + dim=global_dimension, + ) + + area_penalty = np.abs(np.sum(source_area) - penalty.parameter * np.sum(target_area)) + penalties.append(area_penalty * penalty.weight) + + return np.asarray(penalties) diff --git a/glotaran/builtin/models/kinetic_image/kinetic_baseline_megacomplex.py b/glotaran/builtin/models/kinetic_image/kinetic_baseline_megacomplex.py index ad34bf965..4021cb0b5 100644 --- a/glotaran/builtin/models/kinetic_image/kinetic_baseline_megacomplex.py +++ b/glotaran/builtin/models/kinetic_image/kinetic_baseline_megacomplex.py @@ -2,6 +2,7 @@ from __future__ import annotations import numpy as np +import xarray as xr from glotaran.model import DatasetDescriptor from glotaran.model import Megacomplex @@ -12,16 +13,17 @@ class KineticBaselineMegacomplex(Megacomplex): def calculate_matrix( self, - model, dataset_model: DatasetDescriptor, indices: dict[str, int], - axis: dict[str, np.ndarray], **kwargs, ): - size = axis[dataset_model.get_model_dimension()].size + model_dimension = dataset_model.get_model_dimension() + model_axis = dataset_model.get_coords()[model_dimension] compartments = [f"{dataset_model.label}_baseline"] - matrix = np.ones((size, 1), dtype=np.float64) - return (compartments, matrix) + matrix = np.ones((model_axis.size, 1), dtype=np.float64) + return xr.DataArray( + matrix, coords=((model_dimension, model_axis.data), ("clp_label", compartments)) + ) def index_dependent(self, dataset: DatasetDescriptor) -> bool: return False diff --git a/glotaran/builtin/models/kinetic_image/kinetic_decay_megacomplex.py b/glotaran/builtin/models/kinetic_image/kinetic_decay_megacomplex.py index 256395cac..e9995df65 100644 --- a/glotaran/builtin/models/kinetic_image/kinetic_decay_megacomplex.py +++ b/glotaran/builtin/models/kinetic_image/kinetic_decay_megacomplex.py @@ -5,6 +5,7 @@ import numba as nb import numpy as np +import xarray as xr from glotaran.builtin.models.kinetic_image.irf import IrfMultiGaussian from glotaran.model import DatasetDescriptor @@ -46,10 +47,8 @@ def index_dependent(self, dataset: DatasetDescriptor) -> bool: def calculate_matrix( self, - model, dataset_model: DatasetDescriptor, indices: dict[str, int], - axis: dict[str, np.ndarray], **kwargs, ): if dataset_model.initial_concentration is None: @@ -72,9 +71,9 @@ def calculate_matrix( global_dimension = dataset_model.get_global_dimension() global_index = indices.get(global_dimension) - global_axis = axis.get(global_dimension) + global_axis = dataset_model.get_coords().get(global_dimension).values model_dimension = dataset_model.get_model_dimension() - model_axis = axis[model_dimension] + model_axis = dataset_model.get_coords()[model_dimension].values # init the matrix size = (model_axis.size, rates.size) @@ -94,7 +93,9 @@ def calculate_matrix( matrix = matrix @ k_matrix.a_matrix(initial_concentration) # done - return (compartments, matrix) + return xr.DataArray( + matrix, coords=((model_dimension, 
model_axis), ("clp_label", compartments)) + ) def kinetic_image_matrix_implementation( diff --git a/glotaran/builtin/models/kinetic_image/test/test_baseline.py b/glotaran/builtin/models/kinetic_image/test/test_baseline.py index eb1677d51..c855d1103 100644 --- a/glotaran/builtin/models/kinetic_image/test/test_baseline.py +++ b/glotaran/builtin/models/kinetic_image/test/test_baseline.py @@ -1,4 +1,5 @@ import numpy as np +import xarray as xr from glotaran.analysis.util import calculate_matrix from glotaran.builtin.models.kinetic_image import KineticImageModel @@ -39,13 +40,17 @@ def test_baseline(): ] ) - time = np.asarray(np.arange(0, 50, 1.5)) - dataset = model.dataset["dataset1"].fill(model, parameter) - dataset.overwrite_global_dimension("pixel") - compartments, matrix = calculate_matrix(model, dataset, {}, {"time": time}) + time = xr.DataArray(np.asarray(np.arange(0, 50, 1.5))) + pixel = xr.DataArray([0]) + coords = {"time": time, "pixel": pixel} + dataset_model = model.dataset["dataset1"].fill(model, parameter) + dataset_model.overwrite_global_dimension("pixel") + dataset_model.set_coords(coords) + matrix = calculate_matrix(dataset_model, {}) + compartments = matrix.coords["clp_label"] assert len(compartments) == 2 - assert compartments[1] == "dataset1_baseline" + assert compartments[0] == "dataset1_baseline" assert matrix.shape == (time.size, 2) - assert np.all(matrix[:, 1] == 1) + assert np.all(matrix[:, 0] == 1) diff --git a/glotaran/builtin/models/kinetic_spectrum/coherent_artifact_megacomplex.py b/glotaran/builtin/models/kinetic_spectrum/coherent_artifact_megacomplex.py index 98d4780d4..5d4e86d61 100644 --- a/glotaran/builtin/models/kinetic_spectrum/coherent_artifact_megacomplex.py +++ b/glotaran/builtin/models/kinetic_spectrum/coherent_artifact_megacomplex.py @@ -3,6 +3,7 @@ import numba as nb import numpy as np +import xarray as xr from glotaran.builtin.models.kinetic_image.irf import IrfMultiGaussian from glotaran.model import DatasetDescriptor @@ -22,10 +23,8 @@ class CoherentArtifactMegacomplex(Megacomplex): def calculate_matrix( self, - model, dataset_model: DatasetDescriptor, indices: dict[str, int], - axis: dict[str, np.ndarray], **kwargs, ): if not 1 <= self.order <= 3: @@ -39,9 +38,10 @@ def calculate_matrix( global_dimension = dataset_model.get_global_dimension() global_index = indices.get(global_dimension) - global_axis = axis.get(global_dimension) + global_axis = dataset_model.get_coords().get(global_dimension).values model_dimension = dataset_model.get_model_dimension() - model_axis = axis[model_dimension] + model_axis = dataset_model.get_coords()[model_dimension].values + irf = dataset_model.irf center, width, _, _, _, _ = irf.parameter(global_index, global_axis) @@ -49,7 +49,9 @@ def calculate_matrix( width = self.width.value if self.width is not None else width[0] matrix = _calculate_coherent_artifact_matrix(center, width, model_axis, self.order) - return (self.compartments(), matrix) + return xr.DataArray( + matrix, coords=((model_dimension, model_axis), ("clp_label", self.compartments())) + ) def compartments(self): return [f"coherent_artifact_{i}" for i in range(1, self.order + 1)] diff --git a/glotaran/builtin/models/kinetic_spectrum/test/test_coherent_artifact.py b/glotaran/builtin/models/kinetic_spectrum/test/test_coherent_artifact.py index 681f44346..893f019b0 100644 --- a/glotaran/builtin/models/kinetic_spectrum/test/test_coherent_artifact.py +++ b/glotaran/builtin/models/kinetic_spectrum/test/test_coherent_artifact.py @@ -51,15 +51,20 @@ def 
test_coherent_artifact(): ] ) - time = np.asarray(np.arange(0, 50, 1.5)) + time = xr.DataArray(np.arange(0, 50, 1.5)) + spectral = xr.DataArray([0]) + coords = {"time": time, "spectral": spectral} dataset_model = model.dataset["dataset1"].fill(model, parameters) dataset_model.overwrite_global_dimension("spectral") - compartments, matrix = calculate_matrix(model, dataset_model, {}, {"time": time}) + dataset_model.set_coords(coords) + matrix = calculate_matrix(dataset_model, {}) + compartments = matrix.coords["clp_label"].values + print(compartments) assert len(compartments) == 4 for i in range(1, 4): - assert compartments[i] == f"coherent_artifact_{i}" + assert compartments[i - 1] == f"coherent_artifact_{i}" assert matrix.shape == (time.size, 4) diff --git a/glotaran/builtin/models/kinetic_spectrum/test/test_spectral_constraints.py b/glotaran/builtin/models/kinetic_spectrum/test/test_spectral_constraints.py deleted file mode 100644 index ef572e797..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/test/test_spectral_constraints.py +++ /dev/null @@ -1,113 +0,0 @@ -import numpy as np -import xarray as xr - -from glotaran.analysis.optimize import optimize -from glotaran.analysis.simulation import simulate -from glotaran.analysis.util import calculate_matrix -from glotaran.builtin.models.kinetic_spectrum import KineticSpectrumModel -from glotaran.builtin.models.kinetic_spectrum.spectral_constraints import ( - apply_spectral_constraints, -) -from glotaran.parameter import ParameterGroup -from glotaran.project import Scheme - - -def test_spectral_constraint(): - model = KineticSpectrumModel.from_dict( - { - "initial_concentration": { - "j1": { - "compartments": ["s1", "s2"], - "parameters": ["i.1", "i.2"], - }, - }, - "megacomplex": { - "mc1": {"k_matrix": ["k1"]}, - }, - "k_matrix": { - "k1": { - "matrix": { - ("s2", "s1"): "kinetic.1", - ("s2", "s2"): "kinetic.2", - } - } - }, - "spectral_constraints": [ - {"type": "zero", "compartment": "s2", "interval": (float("-inf"), float("inf"))}, - ], - "dataset": { - "dataset1": { - "initial_concentration": "j1", - "megacomplex": ["mc1"], - }, - }, - } - ) - print(model) - - wanted_parameters = ParameterGroup.from_dict( - { - "kinetic": [1e-4, 1e-5], - "i": [1, 2], - } - ) - initial_parameters = ParameterGroup.from_dict( - { - "kinetic": [2e-4, 2e-5], - "i": [1, 2, {"vary": False}], - } - ) - - time = np.asarray(np.arange(0, 50, 1.5)) - dataset = model.dataset["dataset1"].fill(model, wanted_parameters) - dataset.overwrite_global_dimension("spectral") - compartments, matrix = calculate_matrix(model, dataset, {}, {"time": time}) - - assert len(compartments) == 2 - assert matrix.shape == (time.size, 2) - - reduced_compartments, reduced_matrix = apply_spectral_constraints( - model, compartments, matrix, 1 - ) - - print(reduced_matrix) - assert len(reduced_compartments) == 1 - assert reduced_matrix.shape == (time.size, 1) - - reduced_compartments, reduced_matrix = model.constrain_matrix_function( - "dataset1", wanted_parameters, compartments, matrix, 1 - ) - - assert reduced_matrix.shape == (time.size, 1) - - clp = xr.DataArray( - [[1.0, 10.0, 20.0, 1]], coords=(("spectral", [1]), ("clp_label", ["s1", "s2", "s3", "s4"])) - ) - - data = simulate( - model, - "dataset1", - wanted_parameters, - clp=clp, - axes={"time": time, "spectral": np.array([1])}, - ) - - dataset = {"dataset1": data} - scheme = Scheme( - model=model, - parameters=initial_parameters, - data=dataset, - maximum_number_function_evaluations=20, - ) - - result = optimize(scheme) - - 
result_data = result.data["dataset1"] - print(result_data.clp_label) - print(result_data.clp) - # TODO: save reduced clp - # assert result_data.clp.shape == (1, 1) - - print(result_data.species_associated_spectra) - assert result_data.species_associated_spectra.shape == (1, 2) - assert result_data.species_associated_spectra[0, 1] == 0 diff --git a/glotaran/builtin/models/kinetic_spectrum/test/test_spectral_penalties.py b/glotaran/builtin/models/kinetic_spectrum/test/test_spectral_penalties.py index 19bb4453c..99c8b627b 100644 --- a/glotaran/builtin/models/kinetic_spectrum/test/test_spectral_penalties.py +++ b/glotaran/builtin/models/kinetic_spectrum/test/test_spectral_penalties.py @@ -2,17 +2,14 @@ # To add a new markdown cell, type '# %% [markdown]' # %% import importlib -from collections import deque from collections import namedtuple from copy import deepcopy import numpy as np -import pytest from glotaran.analysis.optimize import optimize from glotaran.analysis.simulation import simulate from glotaran.builtin.models.kinetic_spectrum import KineticSpectrumModel -from glotaran.builtin.models.kinetic_spectrum.spectral_penalties import _get_idx_from_interval from glotaran.io import prepare_time_trace_dataset from glotaran.parameter import ParameterGroup from glotaran.project import Scheme @@ -54,25 +51,7 @@ def plot_overview(res, title=None): plt.show(block=False) -@pytest.mark.parametrize("type_factory", [list, deque, tuple, np.array]) -@pytest.mark.parametrize( - "interval,axis,expected", - [ - [(100, 1000), np.linspace(400, 800, 5), (0, 4)], - [(100, 1000), np.linspace(400.0, 800.0, 5), (0, 4)], - [(500, 600), np.linspace(400, 800, 5), (1, 2)], - [(400.0, 800.0), np.linspace(400.0, 800.0, 5), (0, 4)], - [(400.0, np.inf), np.linspace(400.0, 800.0, 5), (0, 4)], - [(0, np.inf), np.linspace(400.0, 800.0, 5), (0, 4)], - [(-np.inf, np.inf), np.linspace(400.0, 800.0, 5), (0, 4)], - ], -) -def test__get_idx_from_interval(type_factory, interval, axis, expected): - axis = type_factory(axis) - assert expected == _get_idx_from_interval(interval, axis) - - -def test_equal_area_penalties(debug=False): +def notest_equal_area_penalties(debug=False): # %% optim_spec = OptimizationSpec(nnls=True, max_nfev=999) @@ -288,9 +267,9 @@ def test_equal_area_penalties(debug=False): assert np.isclose(input_ratio, 1.5038858115) -if __name__ == "__main__": - test__get_idx_from_interval( - type_factory=list, interval=(500, 600), axis=range(400, 800, 100), expected=(1, 2) - ) - test_equal_area_penalties(debug=False) - test_equal_area_penalties(debug=True) +# if __name__ == "__main__": +# test__get_idx_from_interval( +# type_factory=list, interval=(500, 600), axis=range(400, 800, 100), expected=(1, 2) +# ) +# test_equal_area_penalties(debug=False) +# test_equal_area_penalties(debug=True) diff --git a/glotaran/builtin/models/kinetic_spectrum/test/test_spectral_relations.py b/glotaran/builtin/models/kinetic_spectrum/test/test_spectral_relations.py deleted file mode 100644 index 71979eb1a..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/test/test_spectral_relations.py +++ /dev/null @@ -1,138 +0,0 @@ -import numpy as np -import xarray as xr - -from glotaran.analysis.optimize import optimize -from glotaran.analysis.simulation import simulate -from glotaran.analysis.util import calculate_matrix -from glotaran.builtin.models.kinetic_spectrum import KineticSpectrumModel -from glotaran.builtin.models.kinetic_spectrum.spectral_relations import ( - create_spectral_relation_matrix, -) -from glotaran.parameter import 
ParameterGroup -from glotaran.project import Scheme - - -def test_spectral_relation(): - model = KineticSpectrumModel.from_dict( - { - "initial_concentration": { - "j1": { - "compartments": ["s1", "s2", "s3", "s4"], - "parameters": ["i.1", "i.2", "i.3", "i.4"], - }, - }, - "megacomplex": { - "mc1": {"k_matrix": ["k1"]}, - }, - "k_matrix": { - "k1": { - "matrix": { - ("s1", "s1"): "kinetic.1", - ("s2", "s2"): "kinetic.1", - ("s3", "s3"): "kinetic.1", - ("s4", "s4"): "kinetic.1", - } - } - }, - "spectral_relations": [ - { - "compartment": "s1", - "target": "s2", - "parameter": "rel.1", - "interval": [(0, 2)], - }, - { - "compartment": "s1", - "target": "s3", - "parameter": "rel.2", - "interval": [(0, 2)], - }, - ], - "dataset": { - "dataset1": { - "initial_concentration": "j1", - "megacomplex": ["mc1"], - }, - }, - } - ) - print(model) - - rel1, rel2 = 10, 20 - parameters = ParameterGroup.from_dict( - { - "kinetic": [1e-4], - "i": [1, 2, 3, 4], - "rel": [rel1, rel2], - } - ) - - time = np.asarray(np.arange(0, 50, 1.5)) - dataset = model.dataset["dataset1"].fill(model, parameters) - dataset.overwrite_global_dimension("spectral") - compartments, matrix = calculate_matrix(model, dataset, {}, {"time": time}) - - assert len(compartments) == 4 - assert matrix.shape == (time.size, 4) - - reduced_compartments, relation_matrix = create_spectral_relation_matrix( - model, "dataset1", parameters, compartments, matrix, 1 - ) - - print(relation_matrix) - assert len(reduced_compartments) == 2 - assert relation_matrix.shape == (4, 2) - assert np.array_equal( - relation_matrix, - [ - [1.0, 0.0], - [10.0, 0.0], - [20.0, 0.0], - [0.0, 1.0], - ], - ) - - reduced_compartments, reduced_matrix = model.constrain_matrix_function( - "dataset1", parameters, compartments, matrix, 1 - ) - - assert reduced_matrix.shape == (time.size, 2) - - print(reduced_matrix[0, 0], matrix[0, 0], matrix[0, 1], matrix[0, 2]) - assert np.allclose( - reduced_matrix[:, 0], matrix[:, 0] + rel1 * matrix[:, 1] + rel2 * matrix[:, 2] - ) - - clp = xr.DataArray( - [[1.0, 10.0, 20.0, 1]], coords=(("spectral", [1]), ("clp_label", ["s1", "s2", "s3", "s4"])) - ) - - data = simulate( - model, "dataset1", parameters, clp=clp, axes={"time": time, "spectral": np.array([1])} - ) - - dataset = {"dataset1": data} - scheme = Scheme( - model=model, parameters=parameters, data=dataset, maximum_number_function_evaluations=20 - ) - result = optimize(scheme) - - for label, param in result.optimized_parameters.all(): - if param.vary: - assert np.allclose(param.value, parameters.get(label).value, rtol=1e-1) - - result_data = result.data["dataset1"] - print(result_data.species_associated_spectra) - assert result_data.species_associated_spectra.shape == (1, 4) - assert ( - result_data.species_associated_spectra[0, 1] - == rel1 * result_data.species_associated_spectra[0, 0] - ) - assert np.allclose( - result_data.species_associated_spectra[0, 2].values, - rel2 * result_data.species_associated_spectra[0, 0].values, - ) - - -if __name__ == "__main__": - test_spectral_relation() diff --git a/glotaran/builtin/models/spectral/spectral_megacomplex.py b/glotaran/builtin/models/spectral/spectral_megacomplex.py index 02796f3a9..7e3603552 100644 --- a/glotaran/builtin/models/spectral/spectral_megacomplex.py +++ b/glotaran/builtin/models/spectral/spectral_megacomplex.py @@ -3,6 +3,7 @@ from typing import Dict import numpy as np +import xarray as xr from glotaran.model import DatasetDescriptor from glotaran.model import Megacomplex @@ -19,10 +20,8 @@ class 
SpectralMegacomplex(Megacomplex): def calculate_matrix( self, - model, dataset_model: DatasetDescriptor, indices: dict[str, int], - axis: dict[str, np.ndarray], **kwargs, ): @@ -32,13 +31,18 @@ def calculate_matrix( raise ModelError(f"More then one shape defined for compartment '{compartment}'") compartments.append(compartment) - dim1 = axis[dataset_model.get_model_dimension()].size + model_dimension = dataset_model.get_model_dimension() + model_axis = dataset_model.get_coords()[model_dimension] + + dim1 = model_axis.size dim2 = len(self.shape) matrix = np.zeros((dim1, dim2)) for i, shape in enumerate(self.shape.values()): - matrix[:, i] += shape.calculate(axis[dataset_model.get_model_dimension()]) - return compartments, matrix + matrix[:, i] += shape.calculate(model_axis.values) + return xr.DataArray( + matrix, coords=((model_dimension, model_axis.data), ("clp_label", compartments)) + ) def index_dependent(self, dataset: DatasetDescriptor) -> bool: return False diff --git a/glotaran/builtin/models/spectral/test/test_spectral_model.py b/glotaran/builtin/models/spectral/test/test_spectral_model.py index da3f9b8da..be9456c1f 100644 --- a/glotaran/builtin/models/spectral/test/test_spectral_model.py +++ b/glotaran/builtin/models/spectral/test/test_spectral_model.py @@ -63,16 +63,17 @@ class OneCompartmentModel: spectral_parameters = ParameterGroup.from_list([7, 20000, 800]) - time = np.arange(-10, 50, 1.5) - spectral = np.arange(400, 600, 5) + time = xr.DataArray(np.arange(-10, 50, 1.5)) + spectral = xr.DataArray(np.arange(400, 600, 5)) axis = {"time": time, "spectral": spectral} - dataset = kinetic_model.dataset["dataset1"].fill(kinetic_model, kinetic_parameters) - dataset.overwrite_global_dimension("spectral") - kinetic_compartments, kinetic_matrix = calculate_matrix(kinetic_model, dataset, {}, axis) - clp = xr.DataArray( - kinetic_matrix, coords=[("time", time), ("clp_label", kinetic_compartments)] + kinetic_dataset_model = kinetic_model.dataset["dataset1"].fill( + kinetic_model, kinetic_parameters ) + kinetic_dataset_model.overwrite_global_dimension("spectral") + kinetic_dataset_model.set_coords(axis) + clp = calculate_matrix(kinetic_dataset_model, {}) + kinetic_compartments = clp.coords["clp_label"].values class ThreeCompartmentModel: @@ -161,16 +162,17 @@ class ThreeCompartmentModel: ] ) - time = np.arange(-10, 50, 1.5) - spectral = np.arange(400, 600, 5) + time = xr.DataArray(np.arange(-10, 50, 1.5)) + spectral = xr.DataArray(np.arange(400, 600, 5)) axis = {"time": time, "spectral": spectral} - dataset = kinetic_model.dataset["dataset1"].fill(kinetic_model, kinetic_parameters) - dataset.overwrite_global_dimension("spectral") - kinetic_compartments, kinetic_matrix = calculate_matrix(kinetic_model, dataset, {}, axis) - clp = xr.DataArray( - kinetic_matrix, coords=[("time", time), ("clp_label", kinetic_compartments)] + kinetic_dataset_model = kinetic_model.dataset["dataset1"].fill( + kinetic_model, kinetic_parameters ) + kinetic_dataset_model.overwrite_global_dimension("spectral") + kinetic_dataset_model.set_coords(axis) + clp = calculate_matrix(kinetic_dataset_model, {}) + kinetic_compartments = clp.coords["clp_label"].values @pytest.mark.parametrize( diff --git a/glotaran/model/__init__.py b/glotaran/model/__init__.py index 70af72ebb..00379885e 100644 --- a/glotaran/model/__init__.py +++ b/glotaran/model/__init__.py @@ -7,10 +7,15 @@ from glotaran.model.attribute import model_attribute from glotaran.model.attribute import model_attribute_typed from glotaran.model.base_model import Model 
+from glotaran.model.clp_penalties import EqualAreaPenalty +from glotaran.model.constraint import Constraint +from glotaran.model.constraint import OnlyConstraint +from glotaran.model.constraint import ZeroConstraint from glotaran.model.dataset_descriptor import DatasetDescriptor from glotaran.model.decorator import model from glotaran.model.megacomplex import Megacomplex from glotaran.model.megacomplex import megacomplex +from glotaran.model.relation import Relation from glotaran.model.util import ModelError from glotaran.model.weight import Weight from glotaran.plugin_system.model_registration import get_model diff --git a/glotaran/model/clp_penalties.py b/glotaran/model/clp_penalties.py new file mode 100644 index 000000000..6956a7b6e --- /dev/null +++ b/glotaran/model/clp_penalties.py @@ -0,0 +1,161 @@ +"""This package contains compartment constraint items.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import List +from typing import Tuple + +import numpy as np +import xarray as xr + +from glotaran.model import model_attribute +from glotaran.parameter import Parameter + +if TYPE_CHECKING: + from typing import Any + from typing import Sequence + + from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_model import ( + KineticSpectrumModel, + ) + from glotaran.parameter import ParameterGroup + + +@model_attribute( + properties={ + "source": str, + "source_intervals": List[Tuple[float, float]], + "target": str, + "target_intervals": List[Tuple[float, float]], + "parameter": Parameter, + "weight": str, + }, + no_label=True, +) +class EqualAreaPenalty: + """An equal area constraint adds a the differenc of the sum of a + compartments in the e matrix in one ore more intervals to the scaled sum + of the e matrix of one or more target compartments to residual. The additional + residual is scaled with the weight.""" + + def applies(self, index: Any) -> bool: + """ + Returns true if the index is in one of the intervals. 
+ + Parameters + ---------- + index : + + Returns + ------- + applies : bool + + """ + + def applies(interval): + return interval[0] <= index <= interval[1] + + if isinstance(self.interval, tuple): + return applies(self.interval) + return any([applies(i) for i in self.interval]) + + +def has_spectral_penalties(model: KineticSpectrumModel) -> bool: + return len(model.equal_area_penalties) != 0 + + +def apply_spectral_penalties( + model: KineticSpectrumModel, + parameters: ParameterGroup, + clp_labels: dict[str, list[str] | list[list[str]]], + clps: dict[str, list[np.ndarray]], + matrices: dict[str, np.ndarray | list[np.ndarray]], + data: dict[str, xr.Dataset], + group_tolerance: float, +) -> np.ndarray: + + penalties = [] + for penalty in model.equal_area_penalties: + + penalty = penalty.fill(model, parameters) + source_area = _get_area( + model.index_dependent(), + model.global_dimension, + clp_labels, + clps, + data, + group_tolerance, + penalty.source_intervals, + penalty.source, + ) + + target_area = _get_area( + model.index_dependent(), + model.global_dimension, + clp_labels, + clps, + data, + group_tolerance, + penalty.target_intervals, + penalty.target, + ) + + area_penalty = np.abs(np.sum(source_area) - penalty.parameter * np.sum(target_area)) + penalties.append(area_penalty * penalty.weight) + return np.asarray(penalties) + + +def _get_area( + index_dependent: bool, + global_dimension: str, + clp_labels: dict[str, list[list[str]]], + clps: dict[str, list[np.ndarray]], + data: dict[str, xr.Dataset], + group_tolerance: float, + intervals: list[tuple[float, float]], + compartment: str, +) -> np.ndarray: + area = [] + area_indices = [] + + for label, dataset in data.items(): + global_axis = dataset.coords[global_dimension] + for interval in intervals: + if interval[0] > global_axis[-1]: + # interval not in this dataset + continue + + start_idx, end_idx = _get_idx_from_interval(interval, global_axis) + for i in range(start_idx, end_idx + 1): + index_clp_labels = clp_labels[label][i] if index_dependent else clp_labels[label] + if compartment in index_clp_labels: + area.append(clps[label][i][index_clp_labels.index(compartment)]) + area_indices.append(global_axis[i]) + + return np.asarray(area) # TODO: normalize for distance on global axis + + +def _get_idx_from_interval( + interval: tuple[float, float], axis: Sequence[float] | np.ndarray +) -> tuple[int, int]: + """Retrieves start and end index of an interval on some axis + + Parameters + ---------- + interval : A tuple of floats with begin and end of the interval + axis : Array like object which can be cast to np.array + + Returns + ------- + start, end : tuple of int + + """ + axis_array = np.array(axis) + start = np.abs(axis_array - interval[0]).argmin() if not np.isinf(interval[0]) else 0 + end = ( + np.abs(axis_array - interval[1]).argmin() + if not np.isinf(interval[1]) + else axis_array.size - 1 + ) + return start, end diff --git a/glotaran/model/constraint.py b/glotaran/model/constraint.py new file mode 100644 index 000000000..d6683a85f --- /dev/null +++ b/glotaran/model/constraint.py @@ -0,0 +1,66 @@ +"""This package contains compartment constraint items.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from glotaran.model import model_attribute +from glotaran.model import model_attribute_typed +from glotaran.model.interval_property import IntervalProperty + +if TYPE_CHECKING: + from typing import Any + + +@model_attribute( + properties={ + "target": str, + }, + has_type=True, + no_label=True, +) 
+class OnlyConstraint(IntervalProperty): + """A only constraint sets the calculated matrix row of a compartment to 0 + outside the given intervals.""" + + def applies(self, index: Any) -> bool: + """ + Returns true if the indexx is in one of the intervals. + + Parameters + ---------- + index : + + Returns + ------- + applies : bool + + """ + return not super().applies(index) + + +@model_attribute( + properties={ + "target": str, + }, + has_type=True, + no_label=True, +) +class ZeroConstraint(IntervalProperty): + """A zero constraint sets the calculated matrix row of a compartment to 0 + in the given intervals.""" + + +@model_attribute_typed( + types={ + "only": OnlyConstraint, + "zero": ZeroConstraint, + }, + no_label=True, +) +class Constraint: + """A constraint is applied on one clp on one or many + intervals on the estimated axis type. + + There are two types: zero and equal. See the documentation of + the respective classes for details. + """ diff --git a/glotaran/model/dataset_descriptor.py b/glotaran/model/dataset_descriptor.py index 789d15698..93c086959 100644 --- a/glotaran/model/dataset_descriptor.py +++ b/glotaran/model/dataset_descriptor.py @@ -73,11 +73,22 @@ def set_data(self, data: xr.Dataset) -> DatasetDescriptor: self._data = data return self + def get_data(self) -> xr.Dataset: + return self._data + def index_dependent(self) -> bool: if hasattr(self, "_index_dependent"): return self._index_dependent return any(m.index_dependent(self) for m in self.megacomplex) + def set_coords(self, coords: xr.Dataset): + self._coords = coords + + def get_coords(self) -> xr.Dataset: + if hasattr(self, "_coords"): + return self._coords + return self._data.coords + @deprecate( deprecated_qual_name_usage=( "glotaran.model.dataset_descriptor.DatasetDescriptor.overwrite_index_dependent" diff --git a/glotaran/model/decorator.py b/glotaran/model/decorator.py index 23e30e356..347dec4b5 100644 --- a/glotaran/model/decorator.py +++ b/glotaran/model/decorator.py @@ -7,8 +7,11 @@ from glotaran.deprecation import deprecate from glotaran.model.attribute import model_attribute_typed +from glotaran.model.clp_penalties import EqualAreaPenalty +from glotaran.model.constraint import Constraint from glotaran.model.dataset_descriptor import DatasetDescriptor from glotaran.model.megacomplex import Megacomplex +from glotaran.model.relation import Relation from glotaran.model.util import wrap_func_as_method from glotaran.model.weight import Weight from glotaran.plugin_system.model_registration import register_model @@ -195,6 +198,9 @@ def decorator(cls): attributes["dataset"] = dataset_type attributes["megacomplex"] = megacomplex_cls attributes["weights"] = Weight + attributes["relations"] = Relation + attributes["constraints"] = Constraint + attributes["clp_area_penalties"] = EqualAreaPenalty # Set annotations and methods for attributes for attr_name, attr_type in attributes.items(): diff --git a/glotaran/model/interval_property.py b/glotaran/model/interval_property.py new file mode 100644 index 000000000..e1d3c6888 --- /dev/null +++ b/glotaran/model/interval_property.py @@ -0,0 +1,44 @@ +"""Helper functions.""" +from __future__ import annotations + +from typing import Any +from typing import List +from typing import Tuple + +from glotaran.model import model_attribute + + +@model_attribute( + properties={ + "interval": {"type": List[Tuple[Any, Any]], "default": None, "allow_none": True}, + }, + no_label=True, +) +class IntervalProperty: + """Applies a relation between clps as + + :math:`source = parameter * 
target`. + """ + + def applies(self, index: Any) -> bool: + """ + Returns true if the index is in one of the intervals. + + Parameters + ---------- + index : + + Returns + ------- + applies : bool + + """ + if self.interval is None: + return True + + def applies(interval): + return interval[0] <= index <= interval[1] + + if isinstance(self.interval, tuple): + return applies(self.interval) + return not any([applies(i) for i in self.interval]) diff --git a/glotaran/model/megacomplex.py b/glotaran/model/megacomplex.py index eb9dcb4b2..906598bf8 100644 --- a/glotaran/model/megacomplex.py +++ b/glotaran/model/megacomplex.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -import numpy as np +import xarray as xr from glotaran.model import DatasetDescriptor from glotaran.model import model_attribute @@ -37,12 +37,10 @@ class Megacomplex: def calculate_matrix( self, - model, - dataset_descriptor: DatasetDescriptor, + dataset_model: DatasetDescriptor, indices: dict[str, int], - axis: dict[str, np.ndarray], **kwargs, - ): + ) -> xr.DataArray: raise NotImplementedError def index_dependent(self, dataset: DatasetDescriptor) -> bool: diff --git a/glotaran/model/relation.py b/glotaran/model/relation.py new file mode 100644 index 000000000..b5c210d09 --- /dev/null +++ b/glotaran/model/relation.py @@ -0,0 +1,21 @@ +""" Glotaran Relation """ +from __future__ import annotations + +from glotaran.model import model_attribute +from glotaran.model.interval_property import IntervalProperty +from glotaran.parameter import Parameter + + +@model_attribute( + properties={ + "source": str, + "target": str, + "parameter": Parameter, + }, + no_label=True, +) +class Relation(IntervalProperty): + """Applies a relation between clps as + + :math:`target = parameter * source`. + """ diff --git a/tox.ini b/tox.ini index b769ec467..6fe9c651c 100644 --- a/tox.ini +++ b/tox.ini @@ -8,6 +8,10 @@ envlist = py{38}, pre-commit, docs, docs-notebooks, docs-links ; Uncomment the following lines to deactivate pyglotaran all plugins ; env = ; DEACTIVATE_GTA_PLUGINS=1 +; Uncomment to ignore deprecation warnings coming from pyglotaran +; (this helps to see the warnings from dependencies) +; filterwarnings = +; ignore:.+glotaran:DeprecationWarning [flake8] extend-ignore = E231, E203 From a61a39d738b5e27e52357bcb8a046f648b310986 Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Sun, 4 Jul 2021 19:43:26 +0200 Subject: [PATCH 03/29] =?UTF-8?q?=F0=9F=94=A7Add=20more=20QA=20tools=20for?= =?UTF-8?q?=20parts=20of=20glotaran=20(#739)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔧 Allowed mypy to run on a subset of glotaran modules that will be checked are: - utils - plugin_system - deprecation * 🔧 Partially activated docstring QA tools and fixed issues modules that will be checked are: - utils - plugin_system - deprecation --- .pre-commit-config.yaml | 46 ++++++++++++------- glotaran/builtin/io/yml/test/__init__.py | 0 .../builtin/models/spectral/test/__init__.py | 0 glotaran/deprecation/deprecation_utils.py | 8 ++-- glotaran/model/test/__init__.py | 0 .../plugin_system/data_io_registration.py | 2 +- glotaran/plugin_system/model_registration.py | 2 +- .../plugin_system/project_io_registration.py | 2 +- glotaran/utils/ipython.py | 10 ++-- glotaran/utils/test/test_ipython.py | 5 +- setup.cfg | 14 +++++- 11 files changed, 59 insertions(+), 30 deletions(-) create mode 100644 glotaran/builtin/io/yml/test/__init__.py create mode 100644 glotaran/builtin/models/spectral/test/__init__.py create mode 
100644 glotaran/model/test/__init__.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d9d09b6cf..1491a1924 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -57,34 +57,48 @@ repos: args: [--strip-empty-cells] - repo: https://github.com/nbQA-dev/nbQA - rev: 0.13.0 + rev: 0.13.1 hooks: - id: nbqa-black - additional_dependencies: [black==20.8b1] + additional_dependencies: [black==21.6b0] args: [--nbqa-mutate] - id: nbqa-pyupgrade - additional_dependencies: [pyupgrade==2.9.0] + additional_dependencies: [pyupgrade==2.19.4] args: [--nbqa-mutate, --py38-plus] - id: nbqa-flake8 - id: nbqa-check-ast - id: nbqa-isort - additional_dependencies: [isort==5.7.0] + additional_dependencies: [isort==5.9.1] args: [--nbqa-mutate] # Linters - # - repo: https://github.com/PyCQA/pydocstyle - # rev: 5.1.1 - # hooks: - # - id: pydocstyle - # exclude: "docs|tests" - # # this is needed due to the following issue: - # # https://github.com/PyCQA/pydocstyle/issues/368 - # args: [--ignore-decorators=wrap_func_as_method] - # - repo: https://github.com/terrencepreilly/darglint - # rev: v1.5.5 - # hooks: - # - id: darglint + - repo: https://github.com/PyCQA/pydocstyle + rev: 6.1.1 + hooks: + - id: pydocstyle + files: "^glotaran/(plugin_system|utils|deprecation)" + exclude: "docs|tests?" + # this is needed due to the following issue: + # https://github.com/PyCQA/pydocstyle/issues/368 + args: [--ignore-decorators=wrap_func_as_method] + + - repo: https://github.com/terrencepreilly/darglint + rev: v1.8.0 + hooks: + - id: darglint + files: "^glotaran/(plugin_system|utils|deprecation)" + exclude: "docs|tests?" + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.910 + hooks: + - id: mypy + files: "^glotaran/(plugin_system|utils|deprecation)" + exclude: "docs" + args: [glotaran] + pass_filenames: false + additional_dependencies: [types-all] - repo: https://github.com/econchick/interrogate rev: 1.4.0 diff --git a/glotaran/builtin/io/yml/test/__init__.py b/glotaran/builtin/io/yml/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/glotaran/builtin/models/spectral/test/__init__.py b/glotaran/builtin/models/spectral/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/glotaran/deprecation/deprecation_utils.py b/glotaran/deprecation/deprecation_utils.py index 08af03bbd..8f9c18db1 100644 --- a/glotaran/deprecation/deprecation_utils.py +++ b/glotaran/deprecation/deprecation_utils.py @@ -321,7 +321,7 @@ def inner_wrapper(*args: Any, **kwargs: Any) -> DecoratedCallable: def outer_wrapper(deprecated_object: DecoratedCallable) -> DecoratedCallable: """Wrap deprecated_object of all callable kinds.""" - if type(deprecated_object) is not type: + if not isinstance(deprecated_object, type): return cast(DecoratedCallable, inject_warn_into_call(deprecated_object)) setattr( @@ -329,9 +329,9 @@ def outer_wrapper(deprecated_object: DecoratedCallable) -> DecoratedCallable: "__new__", inject_warn_into_call(deprecated_object.__new__), # type: ignore [arg-type] ) - return deprecated_object + return deprecated_object # type: ignore[return-value] - return outer_wrapper + return cast(Callable[[DecoratedCallable], DecoratedCallable], outer_wrapper) def module_attribute(module_qual_name: str, attribute_name: str) -> Any: @@ -488,7 +488,7 @@ def deprecate_submodule( def warn_getattr(attribute_name: str): if attribute_name == "__file__": - return module_attribute(new_module, 
attribute_name) + return new_module.__file__ elif attribute_name in dir(new_module): return deprecate_module_attribute( diff --git a/glotaran/model/test/__init__.py b/glotaran/model/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/glotaran/plugin_system/data_io_registration.py b/glotaran/plugin_system/data_io_registration.py index 8f51686bd..732002254 100644 --- a/glotaran/plugin_system/data_io_registration.py +++ b/glotaran/plugin_system/data_io_registration.py @@ -302,7 +302,7 @@ def data_io_plugin_table(*, plugin_names: bool = False, full_names: bool = False Returns ------- - str + MarkdownStr Markdown table of data io plugins. """ table_data = methods_differ_from_baseclass_table( diff --git a/glotaran/plugin_system/model_registration.py b/glotaran/plugin_system/model_registration.py index 62bb91ea2..f8493292b 100644 --- a/glotaran/plugin_system/model_registration.py +++ b/glotaran/plugin_system/model_registration.py @@ -134,7 +134,7 @@ def model_plugin_table(*, plugin_names: bool = False, full_names: bool = False) Returns ------- - str + MarkdownStr Markdown table of modelnames. """ table_data = [] diff --git a/glotaran/plugin_system/project_io_registration.py b/glotaran/plugin_system/project_io_registration.py index b5f90545f..b2cb10110 100644 --- a/glotaran/plugin_system/project_io_registration.py +++ b/glotaran/plugin_system/project_io_registration.py @@ -488,7 +488,7 @@ def project_io_plugin_table( Returns ------- - str + MarkdownStr Markdown table of project io plugins. """ table_data = methods_differ_from_baseclass_table( diff --git a/glotaran/utils/ipython.py b/glotaran/utils/ipython.py index 9fefddb0b..d4e29cc38 100644 --- a/glotaran/utils/ipython.py +++ b/glotaran/utils/ipython.py @@ -9,8 +9,7 @@ class MarkdownStr(UserString): """String wrapper class for rich display integration of markdown in ipython.""" def __init__(self, wrapped_str: str, *, syntax: str = None): - """String class automatically displayed as markdown by ipython. - + """Initialize string class that is automatically displayed as markdown by ``ipython``. Parameters ---------- @@ -23,13 +22,16 @@ def __init__(self, wrapped_str: str, *, syntax: str = None): ---- Possible syntax highlighting values can e.g. be found here: https://support.codebasehq.com/articles/tips-tricks/syntax-highlighting-in-markdown + + + .. # noqa: DAR101 """ # This needs to be called data since ipython is looking for this attr self.data = str(wrapped_str) self.syntax = syntax def _repr_markdown_(self) -> str: - """Special method used by ``ipython`` to render markdown. + """Render markdown automatically when in a ``ipython`` context. 
See: https://ipython.readthedocs.io/en/latest/config/integrating.html?highlight=_repr_markdown_#rich-display @@ -64,7 +66,7 @@ def display_file(path: str | PathLike[str], *, syntax: str = None) -> MarkdownSt ---------- path : str | PathLike[str] Paths to the file - syntax : str, optional + syntax : str Syntax highlighting which should be applied, by default None Returns diff --git a/glotaran/utils/test/test_ipython.py b/glotaran/utils/test/test_ipython.py index 621aaf334..d8e574bf4 100644 --- a/glotaran/utils/test/test_ipython.py +++ b/glotaran/utils/test/test_ipython.py @@ -39,5 +39,6 @@ def test_display_file(tmp_path: Path): expected = MarkdownStr(file_content, syntax="yaml") tmp_file = tmp_path / "test.yml" tmp_file.write_text(file_content) - for path in (tmp_file, str(tmp_file)): - assert display_file(path, syntax="yaml") == expected + + assert display_file(tmp_file, syntax="yaml") == expected + assert display_file(str(tmp_file), syntax="yaml") == expected diff --git a/setup.cfg b/setup.cfg index c7c3d7c9f..d90957a3c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -75,7 +75,7 @@ ignore_messages = xarraydoc [darglint] docstring_style = numpy -ignore_regex = test_.+|.*wrapper.*|inject_warn_into_call|.*dummy.* +ignore_regex = test_.+|.*wrapper.*|inject_warn_into_call|.*dummy.*|__(str|eq)__ [pydocstyle] convention = numpy @@ -84,3 +84,15 @@ convention = numpy ignore_missing_imports = True scripts_are_modules = True show_error_codes = True + +[mypy-glotaran.*] +ignore_errors = True + +[mypy-glotaran.plugin_system.*] +ignore_errors = False + +[mypy-glotaran.utils.*] +ignore_errors = False + +[mypy-glotaran.deprecation.*] +ignore_errors = False From a2539b5616e00aeaefe994c682c235cdcaeaa698 Mon Sep 17 00:00:00 2001 From: s-weigand Date: Fri, 9 Jul 2021 22:26:32 +0200 Subject: [PATCH 04/29] =?UTF-8?q?=E2=9C=A8=20Feature:=20Megacomplex=20Mode?= =?UTF-8?q?ls=20(#736)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR removes models like kinetic-image in favor of independent megacomplex model. Therefor, there are a lot of internal changes. Key Changes for users: * use `default-megacomplex` instead of `model-type` in yml spec * simulations can now be done by setting `global_megacomplex` in a dataset model * heterogenous dataset models can be analyzed in one optimization, e.g. a spectral and a temporal dataset * the user now needs to set `group=true` in a scheme to have a grouped analysis. 
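
For readers unfamiliar with the new spec, below is a minimal, hypothetical YAML sketch of a megacomplex-based model file assembled only from the bullet points above. Apart from `default-megacomplex` and `global_megacomplex` (and the scheme's `group` option mentioned last), all section names, megacomplex labels and keys (`decay`, `spectral`, `mc1`, `k1`, `type`, `shape`) are illustrative placeholders, not taken from this patch, and the fragment omits the remaining sections (k-matrices, shapes, parameters) a real model needs.

```yaml
# Abbreviated, hypothetical model spec after this change; only the keys named in the
# summary above (default-megacomplex, global_megacomplex) come from this patch,
# everything else is an illustrative placeholder.
default-megacomplex: decay          # replaces the former `model-type` entry

megacomplex:
  mc1:
    k_matrix: [k1]                  # untyped entries fall back to the default megacomplex type
  mc2:
    type: spectral                  # assumed key for choosing a non-default megacomplex type
    shape:
      s1: sh1

dataset:
  dataset1:
    megacomplex: [mc1]
    global_megacomplex: [mc2]       # assumed list form; enables simulation along the global axis
```

Grouped analysis is likewise no longer inferred from the model: the scheme has to request it explicitly (roughly `group: true` in the scheme file, or `Scheme(..., group=True)`), which matches the `scheme.group` check introduced in `optimize` in the diff below.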
--- .github/CODEOWNERS | 54 +- glotaran/analysis/optimize.py | 2 +- glotaran/analysis/problem.py | 51 +- glotaran/analysis/problem_grouped.py | 37 +- glotaran/analysis/problem_ungrouped.py | 25 +- glotaran/analysis/simulation.py | 148 +++--- glotaran/analysis/test/models.py | 412 +++++++--------- glotaran/analysis/test/test_constraints.py | 2 +- glotaran/analysis/test/test_grouping.py | 66 +-- glotaran/analysis/test/test_optimization.py | 37 +- glotaran/analysis/test/test_penalties.py | 6 +- glotaran/analysis/test/test_problem.py | 49 +- glotaran/analysis/test/test_relations.py | 2 +- glotaran/analysis/test/test_simulation.py | 51 +- glotaran/analysis/util.py | 28 +- glotaran/builtin/io/yml/sanatize.py | 75 +++ ...parser_kinetic.py => test_model_parser.py} | 85 ++-- ...l_spec_kinetic.yml => test_model_spec.yml} | 42 +- glotaran/builtin/io/yml/yml.py | 43 +- .../test => megacomplexes}/__init__.py | 0 .../megacomplexes/baseline/__init__.py | 1 + .../baseline/baseline_megacomplex.py | 31 ++ .../test/test_baseline_megacomplex.py} | 15 +- .../coherent_artifact/__init__.py | 3 + .../coherent_artifact_megacomplex.py | 33 +- .../test/test_coherent_artifact.py | 18 +- .../builtin/megacomplexes/decay/__init__.py | 1 + .../megacomplexes/decay/decay_megacomplex.py | 131 +++++ .../decay}/initial_concentration.py | 12 +- .../decay}/irf.py | 103 +++- .../decay}/k_matrix.py | 13 +- .../decay/test/test_decay_megacomplex.py} | 99 ++-- .../decay}/test/test_k_matrix.py | 13 +- .../decay}/test/test_spectral_irf.py | 26 +- glotaran/builtin/megacomplexes/decay/util.py | 220 +++++++++ .../megacomplexes/spectral/__init__.py | 1 + .../spectral/shape.py | 16 +- .../spectral/spectral_megacomplex.py | 80 +++ .../spectral/test/test_spectral_model.py | 52 +- glotaran/builtin/models/__init__.py | 1 - .../builtin/models/kinetic_image/__init__.py | 1 - .../kinetic_image/initial_concentration.pyi | 14 - glotaran/builtin/models/kinetic_image/irf.pyi | 31 -- .../builtin/models/kinetic_image/k_matrix.pyi | 31 -- .../kinetic_baseline_megacomplex.py | 29 -- .../kinetic_decay_megacomplex.py | 189 ------- .../kinetic_image_dataset_descriptor.py | 14 - .../kinetic_image_dataset_descriptor.pyi | 14 - .../kinetic_image/kinetic_image_model.py | 47 -- .../kinetic_image/kinetic_image_model.pyi | 34 -- .../kinetic_image/kinetic_image_result.py | 153 ------ .../models/kinetic_spectrum/__init__.py | 1 - .../kinetic_spectrum_dataset_descriptor.py | 15 - .../kinetic_spectrum_dataset_descriptor.pyi | 10 - .../kinetic_spectrum_model.py | 130 ----- .../kinetic_spectrum_model.pyi | 93 ---- .../kinetic_spectrum_result.py | 82 ---- .../kinetic_spectrum/spectral_constraints.py | 115 ----- .../kinetic_spectrum/spectral_constraints.pyi | 29 -- .../models/kinetic_spectrum/spectral_irf.py | 96 ---- .../models/kinetic_spectrum/spectral_irf.pyi | 32 -- .../kinetic_spectrum/spectral_matrix.py | 34 -- .../kinetic_spectrum/spectral_penalties.py | 161 ------ .../kinetic_spectrum/spectral_penalties.pyi | 32 -- .../kinetic_spectrum/spectral_relations.py | 125 ----- .../kinetic_spectrum/spectral_relations.pyi | 43 -- .../models/kinetic_spectrum/spectral_shape.py | 89 ---- .../kinetic_spectrum/spectral_shape.pyi | 26 - glotaran/builtin/models/spectral/__init__.py | 1 - .../models/spectral/spectral_megacomplex.py | 48 -- .../builtin/models/spectral/spectral_model.py | 28 -- .../models/spectral/spectral_result.py | 42 -- .../modules/test/test_glotaran_root.py | 14 +- .../modules/test/test_model_base_model.py | 30 -- .../test/test_model_dataset_deescriptor.py | 10 
- .../modules/test/test_project_sheme.py | 11 +- glotaran/examples/sequential.py | 43 +- glotaran/model/__init__.py | 19 +- glotaran/model/base_model.py | 283 ----------- glotaran/model/base_model.pyi | 69 --- glotaran/model/clp_penalties.py | 6 +- glotaran/model/constraint.py | 16 +- glotaran/model/dataset_descriptor.py | 102 ---- glotaran/model/dataset_descriptor.pyi | 25 - glotaran/model/dataset_model.py | 162 ++++++ glotaran/model/decorator.py | 460 ------------------ glotaran/model/interval_property.py | 6 +- glotaran/model/{attribute.py => item.py} | 101 ++-- glotaran/model/megacomplex.py | 110 ++++- glotaran/model/model.py | 332 +++++++++++++ glotaran/model/property.py | 10 +- glotaran/model/relation.py | 6 +- glotaran/model/test/test_model.py | 304 ++++++++---- glotaran/model/weight.py | 6 +- glotaran/plugin_system/base_registry.py | 12 +- .../plugin_system/megacomplex_registration.py | 157 ++++++ glotaran/plugin_system/model_registration.py | 150 ------ .../plugin_system/test/test_base_registry.py | 14 +- .../test/test_megacomplex_registration.py | 138 ++++++ .../test/test_model_registration.py | 150 ------ glotaran/project/scheme.py | 1 + glotaran/project/test/test_result.py | 14 +- glotaran/project/test/test_scheme.py | 13 +- .../kinetic_spectrum => }/test/__init__.py | 0 .../test_spectral_decay.py} | 43 +- .../test/test_spectral_penalties.py | 68 ++- requirements_dev.txt | 1 + setup.cfg | 10 +- 108 files changed, 2727 insertions(+), 4042 deletions(-) rename glotaran/builtin/io/yml/test/{test_model_parser_kinetic.py => test_model_parser.py} (63%) rename glotaran/builtin/io/yml/test/{test_model_spec_kinetic.yml => test_model_spec.yml} (72%) rename glotaran/builtin/{models/kinetic_image/test => megacomplexes}/__init__.py (100%) create mode 100644 glotaran/builtin/megacomplexes/baseline/__init__.py create mode 100644 glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py rename glotaran/builtin/{models/kinetic_image/test/test_baseline.py => megacomplexes/baseline/test/test_baseline_megacomplex.py} (75%) create mode 100644 glotaran/builtin/megacomplexes/coherent_artifact/__init__.py rename glotaran/builtin/{models/kinetic_spectrum => megacomplexes/coherent_artifact}/coherent_artifact_megacomplex.py (64%) rename glotaran/builtin/{models/kinetic_spectrum => megacomplexes/coherent_artifact}/test/test_coherent_artifact.py (85%) create mode 100644 glotaran/builtin/megacomplexes/decay/__init__.py create mode 100644 glotaran/builtin/megacomplexes/decay/decay_megacomplex.py rename glotaran/builtin/{models/kinetic_image => megacomplexes/decay}/initial_concentration.py (74%) rename glotaran/builtin/{models/kinetic_image => megacomplexes/decay}/irf.py (53%) rename glotaran/builtin/{models/kinetic_image => megacomplexes/decay}/k_matrix.py (96%) rename glotaran/builtin/{models/kinetic_image/test/test_kinetic_image_model.py => megacomplexes/decay/test/test_decay_megacomplex.py} (78%) rename glotaran/builtin/{models/kinetic_image => megacomplexes/decay}/test/test_k_matrix.py (93%) rename glotaran/builtin/{models/kinetic_spectrum => megacomplexes/decay}/test/test_spectral_irf.py (88%) create mode 100644 glotaran/builtin/megacomplexes/decay/util.py create mode 100644 glotaran/builtin/megacomplexes/spectral/__init__.py rename glotaran/builtin/{models => megacomplexes}/spectral/shape.py (93%) create mode 100644 glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py rename glotaran/builtin/{models => megacomplexes}/spectral/test/test_spectral_model.py (81%) delete mode 100644 
glotaran/builtin/models/__init__.py delete mode 100644 glotaran/builtin/models/kinetic_image/__init__.py delete mode 100644 glotaran/builtin/models/kinetic_image/initial_concentration.pyi delete mode 100644 glotaran/builtin/models/kinetic_image/irf.pyi delete mode 100644 glotaran/builtin/models/kinetic_image/k_matrix.pyi delete mode 100644 glotaran/builtin/models/kinetic_image/kinetic_baseline_megacomplex.py delete mode 100644 glotaran/builtin/models/kinetic_image/kinetic_decay_megacomplex.py delete mode 100644 glotaran/builtin/models/kinetic_image/kinetic_image_dataset_descriptor.py delete mode 100644 glotaran/builtin/models/kinetic_image/kinetic_image_dataset_descriptor.pyi delete mode 100644 glotaran/builtin/models/kinetic_image/kinetic_image_model.py delete mode 100644 glotaran/builtin/models/kinetic_image/kinetic_image_model.pyi delete mode 100644 glotaran/builtin/models/kinetic_image/kinetic_image_result.py delete mode 100644 glotaran/builtin/models/kinetic_spectrum/__init__.py delete mode 100644 glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_dataset_descriptor.py delete mode 100644 glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_dataset_descriptor.pyi delete mode 100644 glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_model.py delete mode 100644 glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_model.pyi delete mode 100644 glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_result.py delete mode 100644 glotaran/builtin/models/kinetic_spectrum/spectral_constraints.py delete mode 100644 glotaran/builtin/models/kinetic_spectrum/spectral_constraints.pyi delete mode 100644 glotaran/builtin/models/kinetic_spectrum/spectral_irf.py delete mode 100644 glotaran/builtin/models/kinetic_spectrum/spectral_irf.pyi delete mode 100644 glotaran/builtin/models/kinetic_spectrum/spectral_matrix.py delete mode 100644 glotaran/builtin/models/kinetic_spectrum/spectral_penalties.py delete mode 100644 glotaran/builtin/models/kinetic_spectrum/spectral_penalties.pyi delete mode 100644 glotaran/builtin/models/kinetic_spectrum/spectral_relations.py delete mode 100644 glotaran/builtin/models/kinetic_spectrum/spectral_relations.pyi delete mode 100644 glotaran/builtin/models/kinetic_spectrum/spectral_shape.py delete mode 100644 glotaran/builtin/models/kinetic_spectrum/spectral_shape.pyi delete mode 100644 glotaran/builtin/models/spectral/__init__.py delete mode 100644 glotaran/builtin/models/spectral/spectral_megacomplex.py delete mode 100644 glotaran/builtin/models/spectral/spectral_model.py delete mode 100644 glotaran/builtin/models/spectral/spectral_result.py delete mode 100644 glotaran/deprecation/modules/test/test_model_base_model.py delete mode 100644 glotaran/deprecation/modules/test/test_model_dataset_deescriptor.py delete mode 100644 glotaran/model/base_model.py delete mode 100644 glotaran/model/base_model.pyi delete mode 100644 glotaran/model/dataset_descriptor.py delete mode 100644 glotaran/model/dataset_descriptor.pyi create mode 100644 glotaran/model/dataset_model.py delete mode 100644 glotaran/model/decorator.py rename glotaran/model/{attribute.py => item.py} (78%) create mode 100644 glotaran/model/model.py create mode 100644 glotaran/plugin_system/megacomplex_registration.py delete mode 100644 glotaran/plugin_system/model_registration.py create mode 100644 glotaran/plugin_system/test/test_megacomplex_registration.py delete mode 100644 glotaran/plugin_system/test/test_model_registration.py rename glotaran/{builtin/models/kinetic_spectrum => 
}/test/__init__.py (100%) rename glotaran/{builtin/models/kinetic_spectrum/test/test_kinetic_spectrum_model.py => test/test_spectral_decay.py} (91%) rename glotaran/{builtin/models/kinetic_spectrum => }/test/test_spectral_penalties.py (81%) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index f11743572..5aa4a20fb 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -12,49 +12,55 @@ # This should make it easy to add new rules without breaking existing ones. # Global rule: -* @glotaran/admins +* @glotaran/admins # tooling -/.github @glotaran/admins @glotaran/maintainers -/.* @glotaran/admins @glotaran/maintainers -/*.y*ml @glotaran/admins @glotaran/maintainers -/*.ini @glotaran/admins @glotaran/maintainers -/*.toml @glotaran/admins @glotaran/maintainers -/*.txt @glotaran/admins @glotaran/maintainers -LICENSE @glotaran/pyglotaran_creators +/.github @glotaran/admins @glotaran/maintainers +/.* @glotaran/admins @glotaran/maintainers +/*.y*ml @glotaran/admins @glotaran/maintainers +/*.ini @glotaran/admins @glotaran/maintainers +/*.toml @glotaran/admins @glotaran/maintainers +/*.txt @glotaran/admins @glotaran/maintainers +LICENSE @glotaran/pyglotaran_creators # docs -/docs/**/*.rst @glotaran/maintainers @glotaran/pyglotaran_creators -# /docs/**/*.md @glotaran/maintainers @glotaran/pyglotaran_creators +/docs/**/*.rst @glotaran/maintainers @glotaran/pyglotaran_creators +# /docs/**/*.md @glotaran/maintainers @glotaran/pyglotaran_creators # analysis module: -/glotaran/analysis/ @jsnel @joernweissenborn +/glotaran/analysis/ @jsnel @joernweissenborn # builtin module: -/glotaran/builtin/io/* @glotaran/admins -/glotaran/builtin/io/ascii @jsnel @glotaran/maintainers -/glotaran/builtin/io/csv @glotaran/maintainers -/glotaran/builtin/io/netCDF @glotaran/maintainers -/glotaran/builtin/io/sdt @glotaran/maintainers -/glotaran/builtin/models/ @jsnel @joernweissenborn +/glotaran/builtin/io/* @glotaran/admins +/glotaran/builtin/io/ascii @jsnel @glotaran/maintainers +/glotaran/builtin/io/csv @glotaran/maintainers +/glotaran/builtin/io/netCDF @glotaran/maintainers +/glotaran/builtin/io/sdt @glotaran/maintainers +/glotaran/builtin/megacomplexes/ @jsnel @joernweissenborn # cli -/glotaran/cli/ @jsnel @glotaran/admins +/glotaran/cli/ @jsnel @glotaran/admins # examples -/glotaran/examples/ @jsnel @glotaran/maintainers +/glotaran/examples/ @jsnel @glotaran/maintainers # io -/glotaran/io/ @jsnel @glotaran/maintainers +/glotaran/io/ @jsnel @glotaran/maintainers # model -/glotaran/model/ @jsnel @glotaran/admins @joernweissenborn +/glotaran/model/ @jsnel @glotaran/admins @joernweissenborn # parameter -/glotaran/parameter/ @jsnel @glotaran/admins @joernweissenborn +/glotaran/parameter/ @jsnel @glotaran/admins @joernweissenborn # plugin_system -glotaran/plugin_system @s-weigand @glotaran/admins +glotaran/plugin_system @s-weigand @glotaran/admins + +# deprecation framework and tests +glotaran/deprecation @s-weigand @glotaran/admins + +# utility function +glotaran/utils @s-weigand @glotaran/admins # project -/glotaran/project/ @jsnel @glotaran/admins +/glotaran/project/ @jsnel @glotaran/admins diff --git a/glotaran/analysis/optimize.py b/glotaran/analysis/optimize.py index 606beb73a..a53dbd5be 100644 --- a/glotaran/analysis/optimize.py +++ b/glotaran/analysis/optimize.py @@ -20,7 +20,7 @@ def optimize(scheme: Scheme, verbose: bool = True) -> Result: - problem = GroupedProblem(scheme) if scheme.model.grouped() else UngroupedProblem(scheme) + problem = GroupedProblem(scheme) if scheme.group else 
UngroupedProblem(scheme) return optimize_problem(problem, verbose=verbose) diff --git a/glotaran/analysis/problem.py b/glotaran/analysis/problem.py index 71cfe8f2f..320dfc9ef 100644 --- a/glotaran/analysis/problem.py +++ b/glotaran/analysis/problem.py @@ -14,7 +14,7 @@ from glotaran.analysis.util import get_min_max_from_interval from glotaran.analysis.variable_projection import residual_variable_projection from glotaran.io.prepare_dataset import add_svd_to_dataset -from glotaran.model import DatasetDescriptor +from glotaran.model import DatasetModel from glotaran.model import Model from glotaran.parameter import ParameterGroup from glotaran.project import Scheme @@ -29,7 +29,7 @@ def __init__(self): class UngroupedProblemDescriptor(NamedTuple): - dataset: DatasetDescriptor + dataset: DatasetModel data: xr.DataArray model_axis: np.ndarray global_axis: np.ndarray @@ -75,19 +75,20 @@ def __init__(self, scheme: Scheme): self._model = scheme.model - self._grouped = scheme.model.grouped() self._bag = None - self._groups = None self._residual_function = ( residual_nnls if scheme.non_negative_least_squares else residual_variable_projection ) self._parameters = None - self._filled_dataset_descriptors = None + self._dataset_models = None - self._overwrite_index_dependent = hasattr(scheme.model, "overwrite_index_dependent") + self._overwrite_index_dependent = self.model.need_index_dependent() self._parameters = scheme.parameters.copy() self._parameter_history = [] + + self._model.validate(raise_exception=True) + self._prepare_data(scheme.data) # all of the above are always not None @@ -145,12 +146,8 @@ def parameter_history(self) -> list[ParameterGroup]: return self._parameter_history @property - def grouped(self) -> bool: - return self._grouped - - @property - def filled_dataset_descriptors(self) -> dict[str, DatasetDescriptor]: - return self._filled_dataset_descriptors + def dataset_models(self) -> dict[str, DatasetModel]: + return self._dataset_models @property def bag(self) -> UngroupedBag | GroupedBag: @@ -158,12 +155,6 @@ def bag(self) -> UngroupedBag | GroupedBag: self.init_bag() return self._bag - @property - def groups(self) -> dict[str, list[str]]: - if not self._groups and self._grouped: - self.init_bag() - return self._groups - @property def matrices( self, @@ -232,14 +223,14 @@ def save_parameters_for_history(self): self._parameter_history.append(self._parameters) def reset(self): - """Resets all results and `DatasetDescriptors`. Use after updating parameters.""" - self._filled_dataset_descriptors = { + """Resets all results and `DatasetModels`. 
Use after updating parameters.""" + self._dataset_models = { label: dataset_model.fill(self._model, self._parameters).set_data(self.data[label]) for label, dataset_model in self._model.dataset.items() } if self._overwrite_index_dependent: - for d in self._filled_dataset_descriptors.values(): - d.overwrite_index_dependent(self.model.overwrite_index_dependent()) + for d in self._dataset_models.values(): + d.overwrite_index_dependent(self._overwrite_index_dependent) self._reset_results() def _reset_results(self): @@ -254,7 +245,7 @@ def _reset_results(self): def _prepare_data(self, data: dict[str, xr.DataArray | xr.Dataset]): self._data = {} - self._filled_dataset_descriptors = {} + self._dataset_models = {} for label, dataset in data.items(): if isinstance(dataset, xr.DataArray): dataset = dataset.to_dataset(name="data") @@ -263,8 +254,8 @@ def _prepare_data(self, data: dict[str, xr.DataArray | xr.Dataset]): dataset_model = dataset_model.fill(self.model, self.parameters) dataset_model.set_data(dataset) if self._overwrite_index_dependent: - dataset_model.overwrite_index_dependent(self.model.overwrite_index_dependent()) - self._filled_dataset_descriptors[label] = dataset_model + dataset_model.overwrite_index_dependent(self._overwrite_index_dependent) + self._dataset_models[label] = dataset_model global_dimension = dataset_model.get_global_dimension() model_dimension = dataset_model.get_model_dimension() @@ -309,7 +300,7 @@ def _add_weight(self, label, dataset): " because weight is already supplied by dataset." ) return - dataset_model = self._filled_dataset_descriptors[label] + dataset_model = self.dataset_models[label] dataset_model.set_data(dataset) global_dimension = dataset_model.get_global_dimension() model_dimension = dataset_model.get_model_dimension() @@ -348,15 +339,15 @@ def create_result_data( if history_index is not None and history_index != -1: self.parameters = self.parameter_history[history_index] result_data = {label: self.create_result_dataset(label, copy=copy) for label in self.data} - - if callable(self.model.finalize_data): - self.model.finalize_data(self, result_data) + for label, dataset_model in self.dataset_models.items(): + result_data[label] = self.create_result_dataset(label, copy=copy) + dataset_model.finalize_data(result_data[label]) return result_data def create_result_dataset(self, label: str, copy: bool = True) -> xr.Dataset: dataset = self.data[label] - dataset_model = self._filled_dataset_descriptors[label] + dataset_model = self.dataset_models[label] global_dimension = dataset_model.get_global_dimension() model_dimension = dataset_model.get_model_dimension() if copy: diff --git a/glotaran/analysis/problem_grouped.py b/glotaran/analysis/problem_grouped.py index 1708f0841..7a93717b4 100644 --- a/glotaran/analysis/problem_grouped.py +++ b/glotaran/analysis/problem_grouped.py @@ -17,7 +17,7 @@ from glotaran.analysis.util import find_overlap from glotaran.analysis.util import reduce_matrix from glotaran.analysis.util import retrieve_clps -from glotaran.model import DatasetDescriptor +from glotaran.model import DatasetModel from glotaran.project import Scheme @@ -34,12 +34,8 @@ def __init__(self, scheme: Scheme): super().__init__(scheme=scheme) # TODO: grouping should be user controlled not inferred automatically - global_dimensions = { - d.get_global_dimension() for d in self.filled_dataset_descriptors.values() - } - model_dimensions = { - d.get_model_dimension() for d in self.filled_dataset_descriptors.values() - } + global_dimensions = {d.get_global_dimension() 
for d in self.dataset_models.values()} + model_dimensions = {d.get_model_dimension() for d in self.dataset_models.values()} if len(global_dimensions) != 1: raise ValueError( f"Cannot group datasets. Global dimensions '{global_dimensions}' do not match." @@ -48,12 +44,11 @@ def __init__(self, scheme: Scheme): raise ValueError( f"Cannot group datasets. Model dimension '{model_dimensions}' do not match." ) - self._index_dependent = any( - d.index_dependent() for d in self.filled_dataset_descriptors.values() - ) + self._index_dependent = any(d.index_dependent() for d in self.dataset_models.values()) self._global_dimension = global_dimensions.pop() self._model_dimension = model_dimensions.pop() self._group_clp_labels = None + self._groups = None def init_bag(self): """Initializes a grouped problem bag.""" @@ -181,6 +176,12 @@ def _append_to_grouped_bag( self._full_axis.append(global_axis[i]) self._bag.append(problem) + @property + def groups(self) -> dict[str, list[str]]: + if not self._groups: + self.init_bag() + return self._groups + def calculate_matrices(self): if self._parameters is None: raise ParameterError @@ -195,7 +196,7 @@ def calculate_index_dependent_matrices( """Calculates the index dependent model matrices.""" def calculate_group( - group: ProblemGroup, descriptors: dict[str, DatasetDescriptor] + group: ProblemGroup, descriptors: dict[str, DatasetModel] ) -> tuple[list[xr.DataArray], xr.DataArray, xr.DataArray]: matrices = [ calculate_matrix( @@ -213,9 +214,7 @@ def calculate_group( ) return matrices, group_clp_labels, reduced_matrix - results = list( - map(lambda group: calculate_group(group, self._filled_dataset_descriptors), self._bag) - ) + results = list(map(lambda group: calculate_group(group, self.dataset_models), self._bag)) matrices = list(map(lambda result: result[0], results)) @@ -238,7 +237,7 @@ def calculate_index_independent_matrices( self._group_clp_labels = {} self._reduced_matrices = {} - for label, dataset_model in self._filled_dataset_descriptors.items(): + for label, dataset_model in self.dataset_models.items(): self._matrices[label] = calculate_matrix( dataset_model, {}, @@ -310,10 +309,10 @@ def _index_dependent_residual( if problem.has_scaling: for i, descriptor in enumerate(problem.descriptor): label = descriptor.label - if self.filled_dataset_descriptors[label] is not None: + if self.dataset_models[label] is not None: start = sum(problem.data_sizes[0:i]) end = start + problem.data_sizes[i] - matrix[start:end, :] *= self.filled_dataset_descriptors[label].scale + matrix[start:end, :] *= self.dataset_models[label].scale reduced_clps, residual = self._residual_function(matrix.values, data) reduced_clps = xr.DataArray( @@ -336,10 +335,10 @@ def _index_independent_residual(self, problem: ProblemGroup, index: any): if problem.has_scaling: for i, descriptor in enumerate(problem.descriptor): label = descriptor.label - if self.filled_dataset_descriptors[label] is not None: + if self.dataset_models[label] is not None: start = sum(problem.data_sizes[0:i]) end = start + problem.data_sizes[i] - matrix[start:end, :] *= self.filled_dataset_descriptors[label].scale + matrix[start:end, :] *= self.dataset_models[label].scale reduced_clps, residual = self._residual_function(matrix.values, data) reduced_clps = xr.DataArray( reduced_clps, dims=["clp_label"], coords={"clp_label": matrix.coords["clp_label"]} diff --git a/glotaran/analysis/problem_ungrouped.py b/glotaran/analysis/problem_ungrouped.py index 38a8d522b..e5f030ef5 100644 --- a/glotaran/analysis/problem_ungrouped.py 
+++ b/glotaran/analysis/problem_ungrouped.py @@ -10,7 +10,7 @@ from glotaran.analysis.util import calculate_matrix from glotaran.analysis.util import reduce_matrix from glotaran.analysis.util import retrieve_clps -from glotaran.model import DatasetDescriptor +from glotaran.model import DatasetModel class UngroupedProblem(Problem): @@ -19,7 +19,7 @@ class UngroupedProblem(Problem): def init_bag(self): """Initializes an ungrouped problem bag.""" self._bag = {} - for label, dataset_model in self.filled_dataset_descriptors.items(): + for label, dataset_model in self.dataset_models.items(): dataset = self._data[label] data = dataset.data weight = dataset.weight if "weight" in dataset else None @@ -48,7 +48,7 @@ def calculate_matrices( self._reduced_matrices = {} for label, problem in self.bag.items(): - dataset_model = self._filled_dataset_descriptors[label] + dataset_model = self.dataset_models[label] if dataset_model.index_dependent(): self._calculate_index_dependent_matrix(label, problem, dataset_model) @@ -58,7 +58,7 @@ def calculate_matrices( return self._matrices, self._reduced_matrices def _calculate_index_dependent_matrix( - self, label: str, problem: UngroupedProblemDescriptor, dataset_model: DatasetDescriptor + self, label: str, problem: UngroupedProblemDescriptor, dataset_model: DatasetModel ): self._matrices[label] = [] self._reduced_matrices[label] = [] @@ -74,7 +74,7 @@ def _calculate_index_dependent_matrix( self._reduced_matrices[label].append(reduced_matrix) def _calculate_index_independent_matrix( - self, label: str, problem: UngroupedProblemDescriptor, dataset_model: DatasetDescriptor + self, label: str, problem: UngroupedProblemDescriptor, dataset_model: DatasetModel ): matrix = calculate_matrix( @@ -117,7 +117,7 @@ def _calculate_residual_for_problem(self, label: str, problem: UngroupedProblemD self._weighted_residuals[label] = [] self._residuals[label] = [] data = problem.data - dataset_model = self._filled_dataset_descriptors[label] + dataset_model = self.dataset_models[label] model_dimension = dataset_model.get_model_dimension() global_dimension = dataset_model.get_global_dimension() @@ -133,7 +133,7 @@ def _calculate_residual_for_problem(self, label: str, problem: UngroupedProblemD else self.reduced_matrices[label] ) if problem.dataset.scale is not None: - reduced_matrix *= self.filled_dataset_descriptors[label].scale + reduced_matrix *= self.dataset_models[label].scale if problem.weight is not None: for j in range(reduced_matrix.shape[1]): @@ -195,8 +195,8 @@ def create_index_independent_result_dataset( return dataset def _add_index_dependent_matrix_to_dataset(self, label: str, dataset: xr.Dataset): - model_dimension = self.filled_dataset_descriptors[label].get_model_dimension() - global_dimension = self.filled_dataset_descriptors[label].get_global_dimension() + model_dimension = self.dataset_models[label].get_model_dimension() + global_dimension = self.dataset_models[label].get_global_dimension() matrix = xr.concat(self.matrices[label], dim=global_dimension) matrix.coords[global_dimension] = dataset.coords[global_dimension] @@ -212,7 +212,7 @@ def _add_index_dependent_matrix_to_dataset(self, label: str, dataset: xr.Dataset def _add_index_independent_matrix_to_dataset(self, label: str, dataset: xr.Dataset): dataset.coords["clp_label"] = self.matrices[label].coords["clp_label"] - model_dimension = self.filled_dataset_descriptors[label].get_model_dimension() + model_dimension = self.dataset_models[label].get_model_dimension() dataset["matrix"] = ( ( (model_dimension), @@ 
-222,8 +222,8 @@ def _add_index_independent_matrix_to_dataset(self, label: str, dataset: xr.Datas ) def _add_residual_and_full_clp_to_dataset(self, label: str, dataset: xr.Dataset): - model_dimension = self.filled_dataset_descriptors[label].get_model_dimension() - global_dimension = self.filled_dataset_descriptors[label].get_global_dimension() + model_dimension = self.dataset_models[label].get_model_dimension() + global_dimension = self.dataset_models[label].get_global_dimension() dataset["clp"] = self.clps[label] dataset["weighted_residual"] = ( ( @@ -245,7 +245,6 @@ def full_penalty(self) -> np.ndarray: if self._full_penalty is None: residuals = self.weighted_residuals additional_penalty = self.additional_penalty - print(residuals) residuals = [np.concatenate(residuals[label]) for label in residuals.keys()] self._full_penalty = ( diff --git a/glotaran/analysis/simulation.py b/glotaran/analysis/simulation.py index c8efb3492..964557251 100644 --- a/glotaran/analysis/simulation.py +++ b/glotaran/analysis/simulation.py @@ -7,6 +7,7 @@ import xarray as xr from glotaran.analysis.util import calculate_matrix +from glotaran.model import DatasetModel if TYPE_CHECKING: from glotaran.model import Model @@ -17,11 +18,11 @@ def simulate( model: Model, dataset: str, parameters: ParameterGroup, - axes: dict[str, np.ndarray] = None, - clp: np.ndarray | xr.DataArray = None, - noise=False, - noise_std_dev=1.0, - noise_seed=None, + coordinates: dict[str, np.ndarray], + clp: xr.DataArray | None = None, + noise: bool = False, + noise_std_dev: float = 1.0, + noise_seed: int | None = None, ): """Simulates a model. @@ -45,85 +46,96 @@ def simulate( The seed for the noise simulation. """ - if model.global_matrix is None and clp is None: + dataset_model = model.dataset[dataset].fill(model, parameters) + dataset_model.set_coordinates(coordinates) + + if dataset_model.global_model(): + result = simulate_global_model( + dataset_model, + parameters, + clp, + ) + elif clp is None: raise ValueError( - "Cannot simulate models without implementation for global matrix and no clp given." + f"Cannot simulate dataset {dataset} without global megacomplex " "and no clp provided." 
+ ) + else: + result = simulate_clp( + dataset_model, + parameters, + clp, ) - filled_dataset = model.dataset[dataset].fill(model, parameters) - filled_dataset.overwrite_global_dimension(model.global_dimension) - if hasattr(model, "overwrite_index_dependent"): - filled_dataset.overwrite_index_dependent(model.overwrite_index_dependent()) + if noise: + if noise_seed is not None: + np.random.seed(noise_seed) + result["data"] = (result.data.dims, np.random.normal(result.data, noise_std_dev)) - model_dimension = filled_dataset.get_model_dimension() - model_axis = axes[model_dimension] - global_dimension = filled_dataset.get_global_dimension() - global_axis = axes[global_dimension] + return result - result = xr.DataArray( - data=0.0, - coords=[ - (model_dimension, model_axis.data), - (global_dimension, global_axis.data), - ], - ) - result = result.to_dataset(name="data") - filled_dataset.set_data(result) - matrix = ( +def simulate_clp( + dataset_model: DatasetModel, + parameters: ParameterGroup, + clp: xr.DataArray, +): + + if "clp_label" not in clp.coords: + raise ValueError("Missing coordinate 'clp_label' in clp.") + global_dimension = next(dim for dim in clp.coords if dim != "clp_label") + + global_axis = clp.coords[global_dimension] + matrices = ( [ calculate_matrix( - filled_dataset, + dataset_model, {global_dimension: index}, ) for index, _ in enumerate(global_axis) ] - if filled_dataset.index_dependent() - else calculate_matrix( - filled_dataset, - {}, - ) + if dataset_model.index_dependent() + else calculate_matrix(dataset_model, {}) ) - if clp is not None: - if clp.shape[0] != global_axis.size: - raise ValueError( - f"Size of dimension 0 of clp ({clp.shape[0]}) != size of axis" - f" '{global_dimension}' ({global_axis.size})" - ) - if isinstance(clp, xr.DataArray): - if global_dimension not in clp.coords: - raise ValueError(f"Missing coordinate '{global_dimension}' in clp.") - if "clp_label" not in clp.coords: - raise ValueError("Missing coordinate 'clp_label' in clp.") - elif "clp_label" not in axes: - raise ValueError("Missing axis 'clp_label'") - else: - clp = xr.DataArray( - clp, - coords=[ - (global_dimension, global_axis), - ("clp_label", axes["clp_label"]), - ], - ) - else: - clp_labels, clp = model.global_matrix(filled_dataset, global_axis) - clp = xr.DataArray( - clp, coords=[(global_dimension, global_axis), ("clp_label", clp_labels)] - ) + model_dimension = dataset_model.get_model_dimension() + model_axis = dataset_model.get_coordinates()[model_dimension] + result = xr.DataArray( + data=0.0, + coords=[ + (model_dimension, model_axis.data), + (global_dimension, global_axis.data), + ], + ) + result = result.to_dataset(name="data") for i in range(global_axis.size): - index_matrix = matrix[i] if filled_dataset.index_dependent() else matrix - print(index_matrix.coords) + index_matrix = matrices[i] if dataset_model.index_dependent() else matrices result.data[:, i] = np.dot( - index_matrix, clp[i].sel(clp_label=index_matrix.coords["clp_label"]) - ) - - if noise: - if noise_seed is not None: - np.random.seed(noise_seed) - result["data"] = ( - (model_dimension, global_dimension), - np.random.normal(result.data, noise_std_dev), + index_matrix, + clp.isel({global_dimension: i}).sel({"clp_label": index_matrix.coords["clp_label"]}), ) return result + + +def simulate_global_model( + dataset_model: DatasetModel, + parameters: ParameterGroup, + clp: xr.DataArray = None, +): + """Simulates a global model.""" + + # TODO: implement full model clp + if clp is not None: + raise 
NotImplementedError("Simulation of full models with clp is not supported yet.") + + if any(m.index_dependent(dataset_model) for m in dataset_model.global_megacomplex): + raise ValueError("Index dependent models for global dimension are not supported.") + + global_matrix = calculate_matrix(dataset_model, {}, global_model=True) + global_matrix = global_matrix.T + + return simulate_clp( + dataset_model, + parameters, + global_matrix, + ) diff --git a/glotaran/analysis/test/models.py b/glotaran/analysis/test/models.py index 5d9d409b1..2d5cc9d81 100644 --- a/glotaran/analysis/test/models.py +++ b/glotaran/analysis/test/models.py @@ -5,36 +5,42 @@ import numpy as np import xarray as xr -from glotaran.model import DatasetDescriptor from glotaran.model import Megacomplex from glotaran.model import Model from glotaran.model import megacomplex -from glotaran.model import model -from glotaran.model import model_attribute from glotaran.parameter import Parameter from glotaran.parameter import ParameterGroup -def calculate_e(dataset, axis): - compartments = ["s1", "s2"] - r_compartments = [] - array = np.zeros((axis.shape[0], len(compartments))) +@megacomplex(dimension="global", properties={}) +class SimpleTestMegacomplexGlobal(Megacomplex): + def calculate_matrix(self, dataset_model, indices, **kwargs): + axis = dataset_model.get_coordinates() + assert "model" in axis + assert "global" in axis + axis = axis["global"] + compartments = ["s1", "s2"] + r_compartments = [] + array = np.zeros((axis.shape[0], len(compartments))) + + for i in range(len(compartments)): + r_compartments.append(compartments[i]) + for j in range(axis.shape[0]): + array[j, i] = (i + j) * axis[j] + return xr.DataArray(array, coords=(("global", axis.data), ("clp_label", r_compartments))) - for i in range(len(compartments)): - r_compartments.append(compartments[i]) - for j in range(axis.shape[0]): - array[j, i] = (i + j) * axis[j] - return (r_compartments, array) + def index_dependent(self, dataset_model): + return False -@megacomplex("c", properties={"is_index_dependent": bool}) +@megacomplex(dimension="model", properties={"is_index_dependent": bool}) class SimpleTestMegacomplex(Megacomplex): def calculate_matrix(self, dataset_model, indices, **kwargs): - axis = dataset_model.get_data().coords - assert "c" in axis - assert "e" in axis + axis = dataset_model.get_coordinates() + assert "model" in axis + assert "global" in axis - axis = axis["c"] + axis = axis["model"] compartments = ["s1", "s2"] r_compartments = [] array = np.zeros((axis.shape[0], len(compartments))) @@ -43,31 +49,35 @@ def calculate_matrix(self, dataset_model, indices, **kwargs): r_compartments.append(compartments[i]) for j in range(axis.shape[0]): array[j, i] = (i + j) * axis[j] - return xr.DataArray(array, coords=(("c", axis.data), ("clp_label", r_compartments))) + return xr.DataArray(array, coords=(("model", axis.data), ("clp_label", r_compartments))) def index_dependent(self, dataset_model): return self.is_index_dependent -@model( - "simple_test", - attributes={}, - model_dimension="c", - global_matrix=calculate_e, - global_dimension="e", - megacomplex_types=SimpleTestMegacomplex, -) class SimpleTestModel(Model): - pass + @classmethod + def from_dict(cls, model_dict): + return super().from_dict( + model_dict, + megacomplex_types={ + "model_complex": SimpleTestMegacomplex, + "global_complex": SimpleTestMegacomplexGlobal, + }, + ) -@megacomplex("c", properties={"is_index_dependent": bool}) +@megacomplex( + dimension="model", + properties={"is_index_dependent": 
bool}, + dataset_properties={ + "kinetic": List[Parameter], + }, +) class SimpleKineticMegacomplex(Megacomplex): def calculate_matrix(self, dataset_model, indices, **kwargs): - axis = dataset_model.get_data().coords - assert "c" in axis - assert "e" in axis - axis = axis["c"] + axis = dataset_model.get_coordinates() + axis = axis["model"] kinpar = -1 * np.asarray(dataset_model.kinetic) if dataset_model.label == "dataset3": # this case is for the ThreeDatasetDecay test @@ -75,199 +85,73 @@ def calculate_matrix(self, dataset_model, indices, **kwargs): else: compartments = [f"s{i+1}" for i in range(len(kinpar))] array = np.exp(np.outer(axis, kinpar)) - return xr.DataArray(array, coords=(("c", axis.data), ("clp_label", compartments))) + return xr.DataArray(array, coords=(("model", axis.data), ("clp_label", compartments))) def index_dependent(self, dataset_model): return self.is_index_dependent + def finalize_data(self, dataset_model, data): + pass -def calculate_spectral_simple(dataset_descriptor, axis): - kinpar = -1 * np.array(dataset_descriptor.kinetic) - if dataset_descriptor.label == "dataset3": - # this case is for the ThreeDatasetDecay test - compartments = [f"s{i+2}" for i in range(len(kinpar))] - else: - compartments = [f"s{i+1}" for i in range(len(kinpar))] - array = np.asarray([[1 for _ in range(axis.size)] for _ in compartments]) - return compartments, array.T - - -def calculate_spectral_gauss(dataset, axis): - location = np.asarray(dataset.location) - amp = np.asarray(dataset.amplitude) - delta = np.asarray(dataset.delta) - - array = np.empty((location.size, axis.size), dtype=np.float64) - - for i in range(location.size): - array[i, :] = amp[i] * np.exp(-np.log(2) * np.square(2 * (axis - location[i]) / delta[i])) - compartments = [f"s{i+1}" for i in range(location.size)] - return compartments, array.T - - -def constrain_matrix_function_typecheck( - model: type[Model], - label: str, - parameters: ParameterGroup, - clp_labels: list[str], - matrix: np.ndarray, - index: float, -): - assert isinstance(label, str) - assert isinstance(parameters, ParameterGroup) - assert isinstance(clp_labels, list) - assert all(isinstance(clp_label, str) for clp_label in clp_labels) - assert isinstance(matrix, np.ndarray) - - model.constrain_matrix_function_called = True - - return (clp_labels, matrix) - - -def retrieve_clp_typecheck( - model: type[Model], - parameters: ParameterGroup, - clp_labels: dict[str, list[str] | list[list[str]]], - reduced_clp_labels: dict[str, list[str] | list[list[str]]], - reduced_clps: dict[str, list[np.ndarray]], - data: dict[str, xr.Dataset], -) -> dict[str, list[np.ndarray]]: - assert isinstance(parameters, ParameterGroup) - - assert isinstance(reduced_clps, dict) - assert all(isinstance(dataset_clps, list) for dataset_clps in reduced_clps.values()) - - assert all( - [isinstance(index_clps, np.ndarray) for index_clps in dataset_clps] - for dataset_clps in reduced_clps.values() - ) - - assert isinstance(data, dict) - assert all(isinstance(label, str) for label in data) - assert all(isinstance(dataset, xr.Dataset) for dataset in data.values()) - - assert isinstance(clp_labels, dict) - assert isinstance(reduced_clp_labels, dict) - assert all( - isinstance(dataset_clp_labels, list) for dataset_clp_labels in reduced_clp_labels.values() - ) - assert all( - [[isinstance(label, str) for label in index_labels] for index_labels in dataset_clp_labels] - for dataset_clp_labels in reduced_clp_labels.values() - ) - model.retrieve_clp_function_called = True 
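
For orientation, the megacomplex interface these rewritten test models rely on can be sketched in isolation. The class and compartment names below are made up for illustration; only the megacomplex decorator, the calculate_matrix/index_dependent signatures, and dataset_model.get_coordinates() come from this patch. The key change is that calculate_matrix now returns an xr.DataArray carrying a clp_label coordinate instead of a (labels, array) tuple:

    import numpy as np
    import xarray as xr

    from glotaran.model import Megacomplex
    from glotaran.model import megacomplex


    # Illustrative sketch: a constant two-compartment matrix over the model axis.
    @megacomplex(dimension="model", properties={})
    class ConstantMegacomplex(Megacomplex):
        def calculate_matrix(self, dataset_model, indices, **kwargs):
            # Axes are read from the dataset model, not from the raw data coords.
            axis = dataset_model.get_coordinates()["model"]
            compartments = ["s1", "s2"]
            array = np.ones((axis.shape[0], len(compartments)))
            # The matrix carries its own clp labels as an xarray coordinate.
            return xr.DataArray(
                array, coords=(("model", axis.data), ("clp_label", compartments))
            )

        def index_dependent(self, dataset_model):
            return False
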
+@megacomplex(dimension="global", properties={}) +class SimpleSpectralMegacomplex(Megacomplex): + def calculate_matrix(self, dataset_model, indices, **kwargs): + axis = dataset_model.get_coordinates() + axis = axis["global"] + kinpar = dataset_model.kinetic + if dataset_model.label == "dataset3": + # this case is for the ThreeDatasetDecay test + compartments = [f"s{i+2}" for i in range(len(kinpar))] + else: + compartments = [f"s{i+1}" for i in range(len(kinpar))] + array = np.asarray([[1 for _ in range(axis.size)] for _ in compartments]).T + return xr.DataArray(array, coords=(("global", axis.data), ("clp_label", compartments))) - return reduced_clps + def index_dependent(self, dataset_model): + return False -def additional_penalty_typecheck( - model: type[Model], - parameters: ParameterGroup, - clp_labels: dict[str, list[str] | list[list[str]]], - clps: dict[str, list[np.ndarray]], - matrices: dict[str, np.ndarray | list[np.ndarray]], - data: dict[str, xr.Dataset], - group_tolerance: float, -) -> np.ndarray: - assert isinstance(parameters, ParameterGroup) - assert isinstance(group_tolerance, float) +@megacomplex( + dimension="global", + properties={ + "location": {"type": List[Parameter], "allow_none": True}, + "amplitude": {"type": List[Parameter], "allow_none": True}, + "delta": {"type": List[Parameter], "allow_none": True}, + }, +) +class ShapedSpectralMegacomplex(Megacomplex): + def calculate_matrix(self, dataset_model, indices, **kwargs): + location = np.asarray(self.location) + amp = np.asarray(self.amplitude) + delta = np.asarray(self.delta) - assert isinstance(clps, dict) - assert all(isinstance(dataset_clps, list) for dataset_clps in clps.values()) - assert all( - [isinstance(index_clps, np.ndarray) for index_clps in dataset_clps] - for dataset_clps in clps.values() - ) + axis = dataset_model.get_coordinates() + axis = axis["global"] + array = np.empty((location.size, axis.size), dtype=np.float64) - assert isinstance(data, dict) - assert all(isinstance(label, str) for label in data) - assert all(isinstance(dataset, xr.Dataset) for dataset in data.values()) - - assert isinstance(clp_labels, dict) - assert isinstance(matrices, dict) - if model.megacomplex["m1"].index_dependent(model.dataset["dataset1"]): - for dataset_clp_labels in clp_labels.values(): - assert all(isinstance(index_label, list) for index_label in dataset_clp_labels) - assert all( - [isinstance(label, str) for label in index_label] - for index_label in dataset_clp_labels + for i in range(location.size): + array[i, :] = amp[i] * np.exp( + -np.log(2) * np.square(2 * (axis - location[i]) / delta[i]) ) + compartments = [f"s{i+1}" for i in range(location.size)] + return xr.DataArray(array.T, coords=(("global", axis.data), ("clp_label", compartments))) - for matrix in matrices.values(): - assert isinstance(matrix, list) - assert all(isinstance(index_matrix, np.ndarray) for index_matrix in matrix) - else: - for dataset_clp_labels in clp_labels.values(): - assert all(isinstance(label, str) for label in dataset_clp_labels) - for matrix in matrices.values(): - assert isinstance(matrix, np.ndarray) - - model.additional_penalty_function_called = True - - return np.asarray([0.1]) - - -@model_attribute( - properties={ - "kinetic": List[Parameter], - } -) -class DecayDatasetDescriptor(DatasetDescriptor): - pass + def index_dependent(self, dataset_model): + return False -@model_attribute( - properties={ - "kinetic": List[Parameter], - "location": {"type": List[Parameter], "allow_none": True}, - "amplitude": {"type": List[Parameter], 
"allow_none": True}, - "delta": {"type": List[Parameter], "allow_none": True}, - } -) -class GaussianShapeDecayDatasetDescriptor(DatasetDescriptor): - pass - - -@model( - "one_channel", - attributes={}, - dataset_type=DecayDatasetDescriptor, - model_dimension="c", - global_matrix=calculate_spectral_simple, - global_dimension="e", - megacomplex_types=SimpleKineticMegacomplex, - # has_additional_penalty_function=lambda model: True, - # additional_penalty_function=additional_penalty_typecheck, - # has_matrix_constraints_function=lambda model: True, - # constrain_matrix_function=constrain_matrix_function_typecheck, - # retrieve_clp_function=retrieve_clp_typecheck, - grouped=lambda model: model.is_grouped, -) class DecayModel(Model): - additional_penalty_function_called = False - constrain_matrix_function_called = False - retrieve_clp_function_called = False - is_grouped = False - - -@model( - "multi_channel", - attributes={}, - dataset_type=GaussianShapeDecayDatasetDescriptor, - model_dimension="c", - global_matrix=calculate_spectral_gauss, - global_dimension="e", - megacomplex_types=SimpleKineticMegacomplex, - grouped=lambda model: model.is_grouped, - # has_additional_penalty_function=lambda model: True, - # additional_penalty_function=additional_penalty_typecheck, -) -class GaussianDecayModel(Model): - additional_penalty_function_called = False - constrain_matrix_function_called = False - retrieve_clp_function_called = False - is_grouped = False + @classmethod + def from_dict(cls, model_dict): + return super().from_dict( + model_dict, + megacomplex_types={ + "model_complex": SimpleKineticMegacomplex, + "global_complex": SimpleSpectralMegacomplex, + "global_complex_shaped": ShapedSpectralMegacomplex, + }, + ) class OneCompartmentDecay: @@ -275,16 +159,32 @@ class OneCompartmentDecay: wanted_parameters = ParameterGroup.from_list([101e-4]) initial_parameters = ParameterGroup.from_list([100e-5, [scale, {"vary": False}]]) - e_axis = np.asarray([1.0]) - c_axis = np.arange(0, 150, 1.5) + global_axis = np.asarray([1.0]) + model_axis = np.arange(0, 150, 1.5) + sim_model_dict = { + "megacomplex": {"m1": {"is_index_dependent": False}, "m2": {"type": "global_complex"}}, + "dataset": { + "dataset1": { + "initial_concentration": [], + "megacomplex": ["m1"], + "global_megacomplex": ["m2"], + "kinetic": ["1"], + } + }, + } + sim_model = DecayModel.from_dict(sim_model_dict) model_dict = { "megacomplex": {"m1": {"is_index_dependent": False}}, "dataset": { - "dataset1": {"initial_concentration": [], "megacomplex": ["m1"], "kinetic": ["1"]} + "dataset1": { + "initial_concentration": [], + "megacomplex": ["m1"], + "kinetic": ["1"], + "scale": "2", + } }, } - sim_model = DecayModel.from_dict(model_dict) model_dict["dataset"]["dataset1"]["scale"] = "2" model = DecayModel.from_dict(model_dict) @@ -293,9 +193,22 @@ class TwoCompartmentDecay: wanted_parameters = ParameterGroup.from_list([11e-4, 22e-5]) initial_parameters = ParameterGroup.from_list([10e-4, 20e-5]) - e_axis = np.asarray([1.0]) - c_axis = np.arange(0, 150, 1.5) + global_axis = np.asarray([1.0]) + model_axis = np.arange(0, 150, 1.5) + sim_model = DecayModel.from_dict( + { + "megacomplex": {"m1": {"is_index_dependent": False}, "m2": {"type": "global_complex"}}, + "dataset": { + "dataset1": { + "initial_concentration": [], + "megacomplex": ["m1"], + "global_megacomplex": ["m2"], + "kinetic": ["1", "2"], + } + }, + } + ) model = DecayModel.from_dict( { "megacomplex": {"m1": {"is_index_dependent": False}}, @@ -308,21 +221,45 @@ class TwoCompartmentDecay: }, 
} ) - sim_model = model class ThreeDatasetDecay: wanted_parameters = ParameterGroup.from_list([101e-4, 201e-3]) initial_parameters = ParameterGroup.from_list([100e-5, 200e-3]) - e_axis = np.asarray([1.0]) - c_axis = np.arange(0, 150, 1.5) + global_axis = np.asarray([1.0]) + model_axis = np.arange(0, 150, 1.5) - e_axis2 = np.asarray([1.0, 2.01]) - c_axis2 = np.arange(0, 100, 1.5) + global_axis2 = np.asarray([1.0, 2.01]) + model_axis2 = np.arange(0, 100, 1.5) - e_axis3 = np.asarray([0.99, 3.0]) - c_axis3 = np.arange(0, 150, 1.5) + global_axis3 = np.asarray([0.99, 3.0]) + model_axis3 = np.arange(0, 150, 1.5) + + sim_model_dict = { + "megacomplex": {"m1": {"is_index_dependent": False}, "m2": {"type": "global_complex"}}, + "dataset": { + "dataset1": { + "initial_concentration": [], + "megacomplex": ["m1"], + "global_megacomplex": ["m2"], + "kinetic": ["1"], + }, + "dataset2": { + "initial_concentration": [], + "megacomplex": ["m1"], + "global_megacomplex": ["m2"], + "kinetic": ["1", "2"], + }, + "dataset3": { + "initial_concentration": [], + "megacomplex": ["m1"], + "global_megacomplex": ["m2"], + "kinetic": ["2"], + }, + }, + } + sim_model = DecayModel.from_dict(sim_model_dict) model_dict = { "megacomplex": {"m1": {"is_index_dependent": False}}, @@ -336,8 +273,7 @@ class ThreeDatasetDecay: "dataset3": {"initial_concentration": [], "megacomplex": ["m1"], "kinetic": ["2"]}, }, } - sim_model = DecayModel.from_dict(model_dict) - model = sim_model + model = DecayModel.from_dict(model_dict) class MultichannelMulticomponentDecay: @@ -366,32 +302,36 @@ class MultichannelMulticomponentDecay: ) initial_parameters = ParameterGroup.from_dict({"k": [0.006, 0.003, 0.0003, 0.03]}) - e_axis = np.arange(12820, 15120, 50) - c_axis = np.arange(0, 150, 1.5) + global_axis = np.arange(12820, 15120, 50) + model_axis = np.arange(0, 150, 1.5) - sim_model = GaussianDecayModel.from_dict( + sim_model = DecayModel.from_dict( { - "compartment": ["s1", "s2", "s3", "s4"], - "megacomplex": {"m1": {"is_index_dependent": False}}, + # "compartment": ["s1", "s2", "s3", "s4"], + "megacomplex": { + "m1": {"is_index_dependent": False}, + "m2": { + "type": "global_complex_shaped", + "location": ["loc.1", "loc.2", "loc.3", "loc.4"], + "delta": ["del.1", "del.2", "del.3", "del.4"], + "amplitude": ["amp.1", "amp.2", "amp.3", "amp.4"], + }, + }, "dataset": { "dataset1": { - "initial_concentration": [], "megacomplex": ["m1"], + "global_megacomplex": ["m2"], "kinetic": ["k.1", "k.2", "k.3", "k.4"], - "location": ["loc.1", "loc.2", "loc.3", "loc.4"], - "delta": ["del.1", "del.2", "del.3", "del.4"], - "amplitude": ["amp.1", "amp.2", "amp.3", "amp.4"], } }, } ) - model = GaussianDecayModel.from_dict( + model = DecayModel.from_dict( { - "compartment": ["s1", "s2", "s3", "s4"], + # "compartment": ["s1", "s2", "s3", "s4"], "megacomplex": {"m1": {"is_index_dependent": False}}, "dataset": { "dataset1": { - "initial_concentration": [], "megacomplex": ["m1"], "kinetic": ["k.1", "k.2", "k.3", "k.4"], } diff --git a/glotaran/analysis/test/test_constraints.py b/glotaran/analysis/test/test_constraints.py index c83eaf88a..ee43af9da 100644 --- a/glotaran/analysis/test/test_constraints.py +++ b/glotaran/analysis/test/test_constraints.py @@ -22,7 +22,7 @@ def test_constraint(index_dependent, grouped): suite.sim_model, "dataset1", suite.wanted_parameters, - {"e": suite.e_axis, "c": suite.c_axis}, + {"global": suite.global_axis, "model": suite.model_axis}, ) scheme = Scheme(model=model, parameters=suite.initial_parameters, data={"dataset1": dataset}) problem 
= GroupedProblem(scheme) if grouped else UngroupedProblem(scheme) diff --git a/glotaran/analysis/test/test_grouping.py b/glotaran/analysis/test/test_grouping.py index 9600dc51e..35ea15f42 100644 --- a/glotaran/analysis/test/test_grouping.py +++ b/glotaran/analysis/test/test_grouping.py @@ -26,12 +26,12 @@ def test_single_dataset(): parameters = ParameterGroup.from_list([1, 10]) print(model.validate(parameters)) assert model.valid(parameters) - axis_e = [1, 2, 3] - axis_c = [5, 7, 9, 12] + global_axis = [1, 2, 3] + model_axis = [5, 7, 9, 12] data = { "dataset1": xr.DataArray( - np.ones((3, 4)), coords=[("e", axis_e), ("c", axis_c)] + np.ones((3, 4)), coords=[("global", global_axis), ("model", model_axis)] ).to_dataset(name="data") } @@ -43,9 +43,9 @@ def test_single_dataset(): assert len(bag) == 3 assert all(p.data.size == 4 for p in bag) assert all(p.descriptor[0].label == "dataset1" for p in bag) - assert all(all(p.descriptor[0].axis["c"] == axis_c) for p in bag) - assert all(all(p.descriptor[0].axis["e"] == axis_e) for p in bag) - assert [p.descriptor[0].indices["e"] for p in bag] == [0, 1, 2] + assert all(all(p.descriptor[0].axis["model"] == model_axis) for p in bag) + assert all(all(p.descriptor[0].axis["global"] == global_axis) for p in bag) + assert [p.descriptor[0].indices["global"] for p in bag] == [0, 1, 2] def test_multi_dataset_no_overlap(): @@ -72,16 +72,16 @@ def test_multi_dataset_no_overlap(): print(model.validate(parameters)) assert model.valid(parameters) - axis_e_1 = [1, 2, 3] - axis_c_1 = [5, 7] - axis_e_2 = [4, 5, 6] - axis_c_2 = [5, 7, 9] + global_axis_1 = [1, 2, 3] + model_axis_1 = [5, 7] + global_axis_2 = [4, 5, 6] + model_axis_2 = [5, 7, 9] data = { "dataset1": xr.DataArray( - np.ones((3, 2)), coords=[("e", axis_e_1), ("c", axis_c_1)] + np.ones((3, 2)), coords=[("global", global_axis_1), ("model", model_axis_1)] ).to_dataset(name="data"), "dataset2": xr.DataArray( - np.ones((3, 3)), coords=[("e", axis_e_2), ("c", axis_c_2)] + np.ones((3, 3)), coords=[("global", global_axis_2), ("model", model_axis_2)] ).to_dataset(name="data"), } @@ -92,15 +92,15 @@ def test_multi_dataset_no_overlap(): assert len(bag) == 6 assert all(p.data.size == 2 for p in bag[:3]) assert all(p.descriptor[0].label == "dataset1" for p in bag[:3]) - assert all(all(p.descriptor[0].axis["c"] == axis_c_1) for p in bag[:3]) - assert all(all(p.descriptor[0].axis["e"] == axis_e_1) for p in bag[:3]) - assert [p.descriptor[0].indices["e"] for p in bag[:3]] == [0, 1, 2] + assert all(all(p.descriptor[0].axis["model"] == model_axis_1) for p in bag[:3]) + assert all(all(p.descriptor[0].axis["global"] == global_axis_1) for p in bag[:3]) + assert [p.descriptor[0].indices["global"] for p in bag[:3]] == [0, 1, 2] assert all(p.data.size == 3 for p in bag[3:]) assert all(p.descriptor[0].label == "dataset2" for p in bag[3:]) - assert all(all(p.descriptor[0].axis["c"] == axis_c_2) for p in bag[3:]) - assert all(all(p.descriptor[0].axis["e"] == axis_e_2) for p in bag[3:]) - assert [p.descriptor[0].indices["e"] for p in bag[3:]] == [0, 1, 2] + assert all(all(p.descriptor[0].axis["model"] == model_axis_2) for p in bag[3:]) + assert all(all(p.descriptor[0].axis["global"] == global_axis_2) for p in bag[3:]) + assert [p.descriptor[0].indices["global"] for p in bag[3:]] == [0, 1, 2] def test_multi_dataset_overlap(): @@ -127,16 +127,16 @@ def test_multi_dataset_overlap(): print(model.validate(parameters)) assert model.valid(parameters) - axis_e_1 = [1, 2, 3, 5] - axis_c_1 = [5, 7] - axis_e_2 = [0, 1.4, 2.4, 3.4, 9] - axis_c_2 
= [5, 7, 9, 12] + global_axis_1 = [1, 2, 3, 5] + model_axis_1 = [5, 7] + global_axis_2 = [0, 1.4, 2.4, 3.4, 9] + model_axis_2 = [5, 7, 9, 12] data = { "dataset1": xr.DataArray( - np.ones((4, 2)), coords=[("e", axis_e_1), ("c", axis_c_1)] + np.ones((4, 2)), coords=[("global", global_axis_1), ("model", model_axis_1)] ).to_dataset(name="data"), "dataset2": xr.DataArray( - np.ones((5, 4)), coords=[("e", axis_e_2), ("c", axis_c_2)] + np.ones((5, 4)), coords=[("global", global_axis_2), ("model", model_axis_2)] ).to_dataset(name="data"), } @@ -150,19 +150,19 @@ def test_multi_dataset_overlap(): assert all(p.data.size == 4 for p in bag[:1]) assert all(p.descriptor[0].label == "dataset1" for p in bag[1:5]) - assert all(all(p.descriptor[0].axis["c"] == axis_c_1) for p in bag[1:5]) - assert all(all(p.descriptor[0].axis["e"] == axis_e_1) for p in bag[1:5]) - assert [p.descriptor[0].indices["e"] for p in bag[1:5]] == [0, 1, 2, 3] + assert all(all(p.descriptor[0].axis["model"] == model_axis_1) for p in bag[1:5]) + assert all(all(p.descriptor[0].axis["global"] == global_axis_1) for p in bag[1:5]) + assert [p.descriptor[0].indices["global"] for p in bag[1:5]] == [0, 1, 2, 3] assert all(p.data.size == 6 for p in bag[1:4]) assert all(p.descriptor[1].label == "dataset2" for p in bag[1:4]) - assert all(all(p.descriptor[1].axis["c"] == axis_c_2) for p in bag[1:4]) - assert all(all(p.descriptor[1].axis["e"] == axis_e_2) for p in bag[1:4]) - assert [p.descriptor[1].indices["e"] for p in bag[1:4]] == [1, 2, 3] + assert all(all(p.descriptor[1].axis["model"] == model_axis_2) for p in bag[1:4]) + assert all(all(p.descriptor[1].axis["global"] == global_axis_2) for p in bag[1:4]) + assert [p.descriptor[1].indices["global"] for p in bag[1:4]] == [1, 2, 3] assert all(p.data.size == 4 for p in bag[5:]) assert bag[4].descriptor[0].label == "dataset1" assert bag[5].descriptor[0].label == "dataset2" - assert np.array_equal(bag[4].descriptor[0].axis["c"], axis_c_1) - assert np.array_equal(bag[5].descriptor[0].axis["c"], axis_c_2) - assert [p.descriptor[0].indices["e"] for p in bag[1:4]] == [0, 1, 2] + assert np.array_equal(bag[4].descriptor[0].axis["model"], model_axis_1) + assert np.array_equal(bag[5].descriptor[0].axis["model"], model_axis_2) + assert [p.descriptor[0].indices["global"] for p in bag[1:4]] == [0, 1, 2] diff --git a/glotaran/analysis/test/test_optimization.py b/glotaran/analysis/test/test_optimization.py index 6f3acc4db..9d29e9efa 100644 --- a/glotaran/analysis/test/test_optimization.py +++ b/glotaran/analysis/test/test_optimization.py @@ -24,23 +24,18 @@ ) @pytest.mark.parametrize( "suite", - # MultichannelMulticomponentDecay], [OneCompartmentDecay, TwoCompartmentDecay, ThreeDatasetDecay, MultichannelMulticomponentDecay], ) def test_optimization(suite, index_dependent, grouped, weight, method): model = suite.model - model.is_grouped = grouped model.megacomplex["m1"].is_index_dependent = index_dependent print("Grouped:", grouped) print("Index dependent:", index_dependent) - assert model.grouped() == grouped - sim_model = suite.sim_model - sim_model.is_grouped = grouped - sim_model.is_index_dependent = index_dependent + sim_model.megacomplex["m1"].is_index_dependent = index_dependent print(model.validate()) assert model.valid() @@ -65,11 +60,14 @@ def test_optimization(suite, index_dependent, grouped, weight, method): nr_datasets = 3 if issubclass(suite, ThreeDatasetDecay) else 1 data = {} for i in range(nr_datasets): - e_axis = getattr(suite, "e_axis" if i == 0 else f"e_axis{i+1}") - c_axis = getattr(suite, 
"c_axis" if i == 0 else f"c_axis{i+1}") + global_axis = getattr(suite, "global_axis" if i == 0 else f"global_axis{i+1}") + model_axis = getattr(suite, "model_axis" if i == 0 else f"model_axis{i+1}") dataset = simulate( - sim_model, f"dataset{i+1}", wanted_parameters, {"e": e_axis, "c": c_axis} + sim_model, + f"dataset{i+1}", + wanted_parameters, + {"global": global_axis, "model": model_axis}, ) print(f"Dataset {i+1}") print("=============") @@ -80,10 +78,10 @@ def test_optimization(suite, index_dependent, grouped, weight, method): if weight: dataset["weight"] = xr.DataArray( - np.ones_like(dataset.data) * 0.5, coords=dataset.coords + np.ones_like(dataset.data) * 0.5, coords=dataset.data.coords ) - assert dataset.data.shape == (c_axis.size, e_axis.size) + assert dataset.data.shape == (model_axis.size, global_axis.size) data[f"dataset{i+1}"] = dataset @@ -92,6 +90,7 @@ def test_optimization(suite, index_dependent, grouped, weight, method): parameters=initial_parameters, data=data, maximum_number_function_evaluations=10, + group=grouped, group_tolerance=0.1, optimization_method=method, ) @@ -118,8 +117,8 @@ def test_optimization(suite, index_dependent, grouped, weight, method): assert "residual_left_singular_vectors" in resultdata assert "residual_right_singular_vectors" in resultdata assert "residual_singular_values" in resultdata - assert np.array_equal(dataset.c, resultdata.c) - assert np.array_equal(dataset.e, resultdata.e) + assert np.array_equal(dataset.coords["model"], resultdata.coords["model"]) + assert np.array_equal(dataset.coords["global"], resultdata.coords["global"]) assert dataset.data.shape == resultdata.data.shape print(dataset.data[0, 0], resultdata.data[0, 0]) assert np.allclose(dataset.data, resultdata.data) @@ -131,15 +130,3 @@ def test_optimization(suite, index_dependent, grouped, weight, method): assert "weighted_residual_left_singular_vectors" in resultdata assert "weighted_residual_right_singular_vectors" in resultdata assert "weighted_residual_singular_values" in resultdata - - # assert callable(model.additional_penalty_function) - # assert model.additional_penalty_function_called - # - # if isinstance(model, DecayModel): - # assert callable(model.constrain_matrix_function) - # assert model.constrain_matrix_function_called - # assert callable(model.retrieve_clp_function) - # assert model.retrieve_clp_function_called - # else: - # assert not model.constrain_matrix_function_called - # assert not model.retrieve_clp_function_called diff --git a/glotaran/analysis/test/test_penalties.py b/glotaran/analysis/test/test_penalties.py index 72019a0d7..ddee6565c 100644 --- a/glotaran/analysis/test/test_penalties.py +++ b/glotaran/analysis/test/test_penalties.py @@ -31,14 +31,14 @@ def test_constraint(index_dependent, grouped): ) parameters = ParameterGroup.from_list([11e-4, 22e-5, 2]) - e_axis = np.arange(50) + global_axis = np.arange(50) print("grouped", grouped, "index_dependent", index_dependent) dataset = simulate( suite.sim_model, "dataset1", parameters, - {"e": e_axis, "c": suite.c_axis}, + {"global": global_axis, "model": suite.model_axis}, ) scheme = Scheme(model=model, parameters=parameters, data={"dataset1": dataset}) problem = GroupedProblem(scheme) if grouped else UngroupedProblem(scheme) @@ -49,5 +49,5 @@ def test_constraint(index_dependent, grouped): assert isinstance(problem.full_penalty, np.ndarray) assert ( problem.full_penalty.size - == (suite.c_axis.size * e_axis.size) + problem.additional_penalty.size + == (suite.model_axis.size * global_axis.size) + 
problem.additional_penalty.size ) diff --git a/glotaran/analysis/test/test_problem.py b/glotaran/analysis/test/test_problem.py index 30d3c60c4..99efd16e7 100644 --- a/glotaran/analysis/test/test_problem.py +++ b/glotaran/analysis/test/test_problem.py @@ -20,17 +20,18 @@ def problem(request) -> Problem: model = suite.model model.megacomplex["m1"].is_index_dependent = request.param[1] - model.is_grouped = request.param[0] model.is_index_dependent = request.param[1] dataset = simulate( suite.sim_model, "dataset1", suite.wanted_parameters, - {"e": suite.e_axis, "c": suite.c_axis}, + {"global": suite.global_axis, "model": suite.model_axis}, ) scheme = Scheme(model=model, parameters=suite.initial_parameters, data={"dataset1": dataset}) - return GroupedProblem(scheme) if request.param[0] else UngroupedProblem(scheme) + problem = GroupedProblem(scheme) if request.param[0] else UngroupedProblem(scheme) + problem.grouped = request.param[0] + return problem def test_problem_bag(problem: Problem): @@ -39,7 +40,7 @@ def test_problem_bag(problem: Problem): if problem.grouped: assert isinstance(bag, collections.deque) - assert len(bag) == suite.e_axis.size + assert len(bag) == suite.global_axis.size assert problem.groups == {"dataset1": ["dataset1"]} else: assert isinstance(bag, dict) @@ -52,7 +53,7 @@ def test_problem_matrices(problem: Problem): if problem.grouped: if problem.model.is_index_dependent: assert all(isinstance(m, xr.DataArray) for m in problem.reduced_matrices) - assert len(problem.reduced_matrices) == suite.e_axis.size + assert len(problem.reduced_matrices) == suite.global_axis.size else: assert "dataset1" in problem.reduced_matrices assert isinstance(problem.reduced_matrices["dataset1"], xr.DataArray) @@ -69,63 +70,61 @@ def test_problem_matrices(problem: Problem): def test_problem_residuals(problem: Problem): - print("Grouped", problem.model.is_grouped, "Indexdep", problem.model.is_index_dependent) problem.calculate_residual() if problem.grouped: assert isinstance(problem.residuals, list) assert all(isinstance(r, np.ndarray) for r in problem.residuals) - assert len(problem.residuals) == suite.e_axis.size + assert len(problem.residuals) == suite.global_axis.size else: assert isinstance(problem.residuals, dict) assert "dataset1" in problem.residuals assert all(isinstance(r, xr.DataArray) for r in problem.residuals["dataset1"]) - assert len(problem.residuals["dataset1"]) == suite.e_axis.size + assert len(problem.residuals["dataset1"]) == suite.global_axis.size assert isinstance(problem.reduced_clps, dict) assert "dataset1" in problem.reduced_clps assert all(isinstance(c, xr.DataArray) for c in problem.reduced_clps["dataset1"]) - assert len(problem.reduced_clps["dataset1"]) == suite.e_axis.size + assert len(problem.reduced_clps["dataset1"]) == suite.global_axis.size assert isinstance(problem.clps, dict) assert "dataset1" in problem.clps assert all(isinstance(c, xr.DataArray) for c in problem.clps["dataset1"]) - assert len(problem.clps["dataset1"]) == suite.e_axis.size + assert len(problem.clps["dataset1"]) == suite.global_axis.size def test_problem_result_data(problem: Problem): - print("Grouped", problem.model.is_grouped, "Indexdep", problem.model.is_index_dependent) data = problem.create_result_data() label = "dataset1" assert label in data dataset = data[label] - dataset_model = problem.filled_dataset_descriptors[label] + dataset_model = problem.dataset_models[label] assert "clp_label" in dataset.coords assert np.array_equal(dataset.clp_label, ["s1", "s2", "s3", "s4"]) assert 
dataset_model.get_global_dimension() in dataset.coords - assert np.array_equal(dataset.coords[dataset_model.get_global_dimension()], suite.e_axis) + assert np.array_equal(dataset.coords[dataset_model.get_global_dimension()], suite.global_axis) assert dataset_model.get_model_dimension() in dataset.coords - assert np.array_equal(dataset.coords[dataset_model.get_model_dimension()], suite.c_axis) + assert np.array_equal(dataset.coords[dataset_model.get_model_dimension()], suite.model_axis) assert "matrix" in dataset matrix = dataset.matrix if problem.model.is_index_dependent: assert len(matrix.shape) == 3 - assert matrix.shape[0] == suite.e_axis.size - assert matrix.shape[1] == suite.c_axis.size + assert matrix.shape[0] == suite.global_axis.size + assert matrix.shape[1] == suite.model_axis.size assert matrix.shape[2] == 4 else: assert len(matrix.shape) == 2 - assert matrix.shape[0] == suite.c_axis.size + assert matrix.shape[0] == suite.model_axis.size assert matrix.shape[1] == 4 assert "clp" in dataset clp = dataset.clp assert len(clp.shape) == 2 - assert clp.shape[0] == suite.e_axis.size + assert clp.shape[0] == suite.global_axis.size assert clp.shape[1] == 4 assert "weighted_residual" in dataset @@ -166,8 +165,8 @@ def test_prepare_data(): dataset = xr.DataArray( np.ones((global_axis.size, model_axis.size)), - coords={"e": global_axis, "c": model_axis}, - dims=("e", "c"), + coords={"global": global_axis, "model": model_axis}, + dims=("global", "model"), ) scheme = Scheme(model, parameters, {"dataset1": dataset}) @@ -180,8 +179,8 @@ def test_prepare_data(): assert data.data.shape == (model_axis.size, global_axis.size) assert data.data.shape == data.weight.shape - assert np.all(data.weight.sel(e=slice(0, 200), c=slice(4, 8)).values == 0.5) - assert np.all(data.weight.sel(c=slice(0, 3)).values == 1) + assert np.all(data.weight.sel({"global": slice(0, 200), "model": slice(4, 8)}).values == 0.5) + assert np.all(data.weight.sel(model=slice(0, 3)).values == 1) model_dict["weights"].append( { @@ -196,8 +195,10 @@ def test_prepare_data(): scheme = Scheme(model, parameters, {"dataset1": dataset}) problem = Problem(scheme) data = problem.data["dataset1"] - assert np.all(data.weight.sel(e=slice(0, 200), c=slice(4, 8)).values == 0.5 * 0.2) - assert np.all(data.weight.sel(c=slice(0, 3)).values == 0.2) + assert np.all( + data.weight.sel({"global": slice(0, 200), "model": slice(4, 8)}).values == 0.5 * 0.2 + ) + assert np.all(data.weight.sel(model=slice(0, 3)).values == 0.2) with pytest.warns( UserWarning, diff --git a/glotaran/analysis/test/test_relations.py b/glotaran/analysis/test/test_relations.py index 2cdbde279..c0841efd3 100644 --- a/glotaran/analysis/test/test_relations.py +++ b/glotaran/analysis/test/test_relations.py @@ -24,7 +24,7 @@ def test_constraint(index_dependent, grouped): suite.sim_model, "dataset1", parameters, - {"e": suite.e_axis, "c": suite.c_axis}, + {"global": suite.global_axis, "model": suite.model_axis}, ) scheme = Scheme(model=model, parameters=parameters, data={"dataset1": dataset}) problem = GroupedProblem(scheme) if grouped else UngroupedProblem(scheme) diff --git a/glotaran/analysis/test/test_simulation.py b/glotaran/analysis/test/test_simulation.py index fc802b52c..f9e246e99 100644 --- a/glotaran/analysis/test/test_simulation.py +++ b/glotaran/analysis/test/test_simulation.py @@ -1,17 +1,24 @@ import numpy as np +import pytest from glotaran.analysis.simulation import simulate from glotaran.analysis.test.models import SimpleTestModel from glotaran.parameter import 
ParameterGroup -def test_simulate_dataset(): +@pytest.mark.parametrize("index_dependent", [True, False]) +@pytest.mark.parametrize("noise", [True, False]) +def test_simulate_dataset(index_dependent, noise): model = SimpleTestModel.from_dict( { - "megacomplex": {"m1": {"is_index_dependent": False}}, + "megacomplex": { + "m1": {"is_index_dependent": index_dependent}, + "m2": {"type": "global_complex"}, + }, "dataset": { "dataset1": { "megacomplex": ["m1"], + "global_megacomplex": ["m2"], }, }, } @@ -23,21 +30,29 @@ def test_simulate_dataset(): print(model.validate(parameter)) assert model.valid(parameter) - est_axis = np.asarray([1, 1, 1, 1]) - cal_axis = np.asarray([2, 2, 2]) + global_axis = np.asarray([1, 1, 1, 1]) + model_axis = np.asarray([2, 2, 2]) - data = simulate(model, "dataset1", parameter, {"e": est_axis, "c": cal_axis}) - assert np.array_equal(data["c"], cal_axis) - assert np.array_equal(data["e"], est_axis) - assert data.data.shape == (3, 4) - assert np.array_equal( - data.data, - np.asarray( - [ - [2, 4, 6], - [4, 10, 16], - [6, 16, 26], - [8, 22, 36], - ] - ).T, + data = simulate( + model, + "dataset1", + parameter, + {"global": global_axis, "model": model_axis}, + noise=noise, + noise_std_dev=0.1, ) + assert np.array_equal(data["global"], global_axis) + assert np.array_equal(data["model"], model_axis) + assert data.data.shape == (3, 4) + if not noise: + assert np.array_equal( + data.data, + np.asarray( + [ + [2, 4, 6], + [4, 10, 16], + [6, 16, 26], + [8, 22, 36], + ] + ).T, + ) diff --git a/glotaran/analysis/util.py b/glotaran/analysis/util.py index f1ee25df9..dc857e6dd 100644 --- a/glotaran/analysis/util.py +++ b/glotaran/analysis/util.py @@ -6,7 +6,7 @@ import numpy as np import xarray as xr -from glotaran.model import DatasetDescriptor +from glotaran.model import DatasetModel from glotaran.model import Model from glotaran.parameter import ParameterGroup @@ -40,13 +40,20 @@ def get_min_max_from_interval(interval, axis): def calculate_matrix( - dataset_descriptor: DatasetDescriptor, + dataset_model: DatasetModel, indices: dict[str, int] | None, + global_model: bool = False, ) -> xr.DataArray: matrix = None - for scale, megacomplex in dataset_descriptor.iterate_megacomplexes(): - this_matrix = megacomplex.calculate_matrix(dataset_descriptor, indices) + megacomplex_iterator = dataset_model.iterate_megacomplexes + + if global_model: + megacomplex_iterator = dataset_model.iterate_global_megacomplexes + dataset_model.swap_dimensions() + + for scale, megacomplex in megacomplex_iterator(): + this_matrix = megacomplex.calculate_matrix(dataset_model, indices) if scale is not None: this_matrix *= scale @@ -58,6 +65,9 @@ def calculate_matrix( matrix = matrix.fillna(0) matrix += this_matrix.fillna(0) + if global_model: + dataset_model.swap_dimensions() + return matrix @@ -143,13 +153,13 @@ def retrieve_clps( clps = xr.DataArray(np.zeros((clp_labels.size), dtype=np.float64), coords=[clp_labels]) clps.loc[{"clp_label": reduced_clps.coords["clp_label"]}] = reduced_clps.values - print("ret", clps) for relation in model.relations: relation = relation.fill(model, parameters) - print("YYY", relation.target, relation.source, relation.parameter) - if relation.target in clp_labels and relation.applies(index): - if relation.source not in clp_labels: - continue + if ( + relation.target in clp_labels + and relation.applies(index) + and relation.source in clp_labels + ): clps.loc[{"clp_label": relation.target}] = relation.parameter * clps.sel( clp_label=relation.source ) diff --git 
a/glotaran/builtin/io/yml/sanatize.py b/glotaran/builtin/io/yml/sanatize.py index 706e2f537..fe3084656 100644 --- a/glotaran/builtin/io/yml/sanatize.py +++ b/glotaran/builtin/io/yml/sanatize.py @@ -3,6 +3,8 @@ from typing import Tuple from typing import Union +from glotaran.deprecation import warn_deprecated + # tuple_pattern = re.compile(r"(\(.*?,.*?\))") tuple_number_pattern = re.compile(r"(\([\s\d.+-]+?[,\s\d.+-]*?\))") number_pattern = re.compile(r"[\d.+-]+") @@ -160,3 +162,76 @@ def sanitize_yaml(d: dict, do_keys: bool = True, do_values: bool = False) -> dic # this is only needed to allow for tuple parsing in specification sanitize_dict_values(d) return d + + +def check_deprecations(spec: dict): + if "type" in spec: + if spec["type"] == "kinetic-spectrum": + warn_deprecated( + deprecated_qual_name_usage="type: kinectic-spectrum", + new_qual_name_usage="default-megacomplex: decay", + to_be_removed_in_version="0.6.0", + check_qual_names=(False, False), + ) + spec["default-megacomplex"] = "decay" + elif spec["type"] == "spectral": + warn_deprecated( + deprecated_qual_name_usage="type: spectral", + new_qual_name_usage="default-megacomplex: spectral", + to_be_removed_in_version="0.6.0", + check_qual_names=(False, False), + ) + spec["default-megacomplex"] = "spectral" + del spec["type"] + + if "spectral_relations" in spec: + warn_deprecated( + deprecated_qual_name_usage="spectral_relations", + new_qual_name_usage="relations", + to_be_removed_in_version="0.6.0", + check_qual_names=(False, False), + ) + spec["relations"] = spec["spectral_relations"] + del spec["spectral_relations"] + + for i, relation in enumerate(spec["relations"]): + if "compartment" in relation: + warn_deprecated( + deprecated_qual_name_usage="relation.compartment", + new_qual_name_usage="relation.source", + to_be_removed_in_version="0.6.0", + check_qual_names=(False, False), + ) + relation["source"] = relation["compartment"] + del relation["compartment"] + + if "spectral_constraints" in spec: + warn_deprecated( + deprecated_qual_name_usage="spectral_constraints", + new_qual_name_usage="constraints", + to_be_removed_in_version="0.6.0", + check_qual_names=(False, False), + ) + spec["constraints"] = spec["spectral_constraints"] + del spec["spectral_constraints"] + + for i, constraint in enumerate(spec["constraints"]): + if "compartment" in constraint: + warn_deprecated( + deprecated_qual_name_usage="constraint.compartment", + new_qual_name_usage="constraint.target", + to_be_removed_in_version="0.6.0", + check_qual_names=(False, False), + ) + constraint["target"] = constraint["compartment"] + del constraint["compartment"] + + if "equal_area_penalties" in spec: + warn_deprecated( + deprecated_qual_name_usage="equal_area_penalties", + new_qual_name_usage="clp_area_penalties", + to_be_removed_in_version="0.6.0", + check_qual_names=(False, False), + ) + spec["clp_area_penalties"] = spec["equal_area_penalties"] + del spec["equal_area_penalties"] diff --git a/glotaran/builtin/io/yml/test/test_model_parser_kinetic.py b/glotaran/builtin/io/yml/test/test_model_parser.py similarity index 63% rename from glotaran/builtin/io/yml/test/test_model_parser_kinetic.py rename to glotaran/builtin/io/yml/test/test_model_parser.py index 0d97ab411..b19a48dab 100644 --- a/glotaran/builtin/io/yml/test/test_model_parser_kinetic.py +++ b/glotaran/builtin/io/yml/test/test_model_parser.py @@ -5,18 +5,17 @@ import numpy as np import pytest -from glotaran.builtin.models.kinetic_image.initial_concentration import InitialConcentration -from 
glotaran.builtin.models.kinetic_image.irf import IrfMultiGaussian -from glotaran.builtin.models.kinetic_image.kinetic_decay_megacomplex import KineticDecayMegacomplex -from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_dataset_descriptor import ( - KineticSpectrumDatasetDescriptor, -) -from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_model import KineticSpectrumModel -from glotaran.builtin.models.kinetic_spectrum.spectral_constraints import ZeroConstraint -from glotaran.builtin.models.kinetic_spectrum.spectral_penalties import EqualAreaPenalty -from glotaran.builtin.models.kinetic_spectrum.spectral_shape import SpectralShapeGaussian +from glotaran.builtin.megacomplexes.decay.decay_megacomplex import DecayMegacomplex +from glotaran.builtin.megacomplexes.decay.initial_concentration import InitialConcentration +from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian +from glotaran.builtin.megacomplexes.spectral.shape import SpectralShapeSkewedGaussian from glotaran.io import load_model +from glotaran.model import DatasetModel +from glotaran.model import Model from glotaran.model import Weight +from glotaran.model.clp_penalties import EqualAreaPenalty +from glotaran.model.constraint import OnlyConstraint +from glotaran.model.constraint import ZeroConstraint from glotaran.parameter import ParameterGroup THIS_DIR = dirname(abspath(__file__)) @@ -24,15 +23,17 @@ @pytest.fixture def model(): - spec_path = join(THIS_DIR, "test_model_spec_kinetic.yml") + spec_path = join(THIS_DIR, "test_model_spec.yml") m = load_model(spec_path) print(m.markdown()) return m def test_correct_model(model): - assert type(model).__name__ == "KineticSpectrumModel" - assert isinstance(model, KineticSpectrumModel) + assert isinstance(model, Model) + assert "decay" == model.default_megacomplex + assert "decay" in model.megacomplex_types + assert "spectral" in model.megacomplex_types def test_dataset(model): @@ -40,37 +41,33 @@ def test_dataset(model): assert "dataset1" in model.dataset dataset = model.dataset["dataset1"] - assert isinstance(dataset, KineticSpectrumDatasetDescriptor) + assert isinstance(dataset, DatasetModel) assert dataset.label == "dataset1" assert dataset.megacomplex == ["cmplx1"] assert dataset.initial_concentration == "inputD1" assert dataset.irf == "irf1" assert dataset.scale == 1 - assert len(dataset.shape) == 2 - assert dataset.shape["s1"] == "shape1" - assert dataset.shape["s2"] == "shape2" - dataset = model.dataset["dataset2"] +def test_constraints(model): + print(model.constraints) + assert len(model.constraints) == 2 + zero = model.constraints[0] + assert isinstance(zero, ZeroConstraint) + assert zero.target == "s1" + assert zero.interval == [[1, 100], [2, 200]] -def test_spectral_constraints(model): - print(model.spectral_constraints) - assert len(model.spectral_constraints) == 2 + only = model.constraints[1] + assert isinstance(only, OnlyConstraint) + assert only.target == "s1" + assert only.interval == [[1, 100], [2, 200]] - assert any(isinstance(c, ZeroConstraint) for c in model.spectral_constraints) - zcs = [zc for zc in model.spectral_constraints if zc.type == "zero"] - assert len(zcs) == 2 - for zc in zcs: - assert zc.compartment == "s1" - assert zc.interval == [[1, 100], [2, 200]] - - -def test_spectral_penalties(model): - assert len(model.equal_area_penalties) == 1 - assert all(isinstance(c, EqualAreaPenalty) for c in model.equal_area_penalties) - eac = model.equal_area_penalties[0] +def test_penalties(model): + assert len(model.clp_area_penalties) 
== 1 + assert all(isinstance(c, EqualAreaPenalty) for c in model.clp_area_penalties) + eac = model.clp_area_penalties[0] assert eac.source == "s3" assert eac.source_intervals == [[670, 810]] assert eac.target == "s2" @@ -79,13 +76,13 @@ def test_spectral_penalties(model): assert eac.weight == 0.0016 -def test_spectral_relations(model): - print(model.spectral_relations) - assert len(model.spectral_relations) == 1 +def test_relations(model): + print(model.relations) + assert len(model.relations) == 1 - rel = model.spectral_relations[0] + rel = model.relations[0] - assert rel.compartment == "s1" + assert rel.source == "s1" assert rel.target == "s2" assert rel.interval == [[1, 100], [2, 200]] @@ -157,7 +154,7 @@ def test_shapes(model): assert "shape1" in model.shape shape = model.shape["shape1"] - assert isinstance(shape, SpectralShapeGaussian) + assert isinstance(shape, SpectralShapeSkewedGaussian) assert shape.amplitude.full_label == "shape.1" assert shape.location.full_label == "shape.2" assert shape.width.full_label == "shape.3" @@ -166,10 +163,16 @@ def test_shapes(model): def test_megacomplexes(model): assert len(model.megacomplex) == 3 - for i, _ in enumerate(model.megacomplex, start=1): + for i in range(1, 3): label = f"cmplx{i}" assert label in model.megacomplex megacomplex = model.megacomplex[label] - assert isinstance(megacomplex, KineticDecayMegacomplex) + assert isinstance(megacomplex, DecayMegacomplex) assert megacomplex.label == label assert megacomplex.k_matrix == [f"km{i}"] + + assert "cmplx3" in model.megacomplex + megacomplex = model.megacomplex["cmplx3"] + assert len(megacomplex.shape) == 2 + assert megacomplex.shape["s1"] == "shape1" + assert megacomplex.shape["s2"] == "shape2" diff --git a/glotaran/builtin/io/yml/test/test_model_spec_kinetic.yml b/glotaran/builtin/io/yml/test/test_model_spec.yml similarity index 72% rename from glotaran/builtin/io/yml/test/test_model_spec_kinetic.yml rename to glotaran/builtin/io/yml/test/test_model_spec.yml index f34b73ed4..499e61b8f 100644 --- a/glotaran/builtin/io/yml/test/test_model_spec_kinetic.yml +++ b/glotaran/builtin/io/yml/test/test_model_spec.yml @@ -1,4 +1,4 @@ -type: kinetic-spectrum +default-megacomplex: decay dataset: @@ -7,9 +7,6 @@ dataset: initial_concentration: inputD1 irf: irf1 scale: 1 - shape: - s1: shape1 - s2: shape2 dataset2: megacomplex: [cmplx2] initial_concentration: inputD2 @@ -21,7 +18,18 @@ irf: type: gaussian center: [1] width: [2] - irf2: [spectral-gaussian, [1, 2], [3,4], [9], None, false, true, 55, 55, [5,6], [7,8], true] # compact + irf2: + type: spectral-gaussian + center: [1, 2] + width: [3,4] + scale: [9] + normalize: false + backsweep: true + backsweep_period: 55 + dispersion_center: 55 + center_dispersion: [5,6] + width_dispersion: [7,8] + model_dispersion_with_wavenumber: true initial_concentration: inputD1: @@ -46,7 +54,7 @@ k_matrix: shape: shape1: - type: "gaussian" + type: "skewed-gaussian" amplitude: shape.1 location: shape.2 width: shape.3 @@ -57,18 +65,24 @@ megacomplex: cmplx2: k_matrix: [km2] cmplx3: - type: "kinetic-decay" - k_matrix: [km3] + type: "spectral" + shape: + s1: shape1 + s2: shape2 -spectral_constraints: +constraints: - type: zero - compartment: s1 + target: s1 + interval: + - [1, 100] + - [2, 200] + - type: only + target: s1 interval: - [1, 100] - [2, 200] - - [zero, s1, [[1, 100], [2, 200]]] -equal_area_penalties: +clp_area_penalties: - type: equal_area source: s3 source_intervals: [[670, 810]] @@ -77,8 +91,8 @@ equal_area_penalties: parameter: 55 weight: 0.0016 
-spectral_relations: - - compartment: s1 +relations: + - source: s1 target: s2 parameter: 8 interval: [[1,100], [2,200]] diff --git a/glotaran/builtin/io/yml/yml.py b/glotaran/builtin/io/yml/yml.py index b94b05389..63fe29884 100644 --- a/glotaran/builtin/io/yml/yml.py +++ b/glotaran/builtin/io/yml/yml.py @@ -7,6 +7,7 @@ import yaml +from glotaran.builtin.io.yml.sanatize import check_deprecations from glotaran.builtin.io.yml.sanatize import sanitize_yaml from glotaran.io import ProjectIoInterface from glotaran.io import load_dataset @@ -15,13 +16,13 @@ from glotaran.io import register_project_io from glotaran.io import save_dataset from glotaran.io import save_parameters -from glotaran.model import get_model +from glotaran.model import Model +from glotaran.model import get_megacomplex from glotaran.parameter import ParameterGroup from glotaran.project import SavingOptions from glotaran.project import Scheme if TYPE_CHECKING: - from glotaran.model import Model from glotaran.project import Result @@ -48,16 +49,35 @@ def load_model(self, file_name: str) -> Model: with open(file_name) as f: spec = yaml.safe_load(f) - spec = sanitize_yaml(spec) - - if "type" not in spec: - raise Exception("Model type not defined") + check_deprecations(spec) - model_type = spec["type"] - del spec["type"] + spec = sanitize_yaml(spec) - model = get_model(model_type) - return model.from_dict(spec) + default_megacomplex = spec.get("default-megacomplex") + + if default_megacomplex is None and any( + "type" not in m for m in spec["megacomplex"].values() + ): + raise ValueError( + "Default megacomplex is not defined in model and " + "at least one megacomplex does not have a type." + ) + + if "megacomplex" not in spec: + raise ValueError("No megacomplex defined in model") + + megacomplex_types = { + m["type"]: get_megacomplex(m["type"]) + for m in spec["megacomplex"].values() + if "type" in m + } + if default_megacomplex is not None: + megacomplex_types[default_megacomplex] = get_megacomplex(default_megacomplex) + del spec["default-megacomplex"] + + return Model.from_dict( + spec, megacomplex_types=megacomplex_types, default_megacomplex_type=default_megacomplex + ) def load_parameters(self, file_name: str) -> ParameterGroup: @@ -122,6 +142,7 @@ def load_scheme(self, file_name: str) -> Scheme: ftol = scheme.get("ftol", 1e-8) gtol = scheme.get("gtol", 1e-8) xtol = scheme.get("xtol", 1e-8) + group = scheme.get("group", False) group_tolerance = scheme.get("group_tolerance", 0.0) saving = SavingOptions(**scheme.get("saving", {})) return Scheme( @@ -133,6 +154,7 @@ def load_scheme(self, file_name: str) -> Scheme: ftol=ftol, gtol=gtol, xtol=xtol, + group=group, group_tolerance=group_tolerance, optimization_method=optimization_method, saving=saving, @@ -156,6 +178,7 @@ def save_result(self, result: Result, result_path: str): scheme_path = os.path.join(result_path, "scheme.yml") result_scheme = dataclasses.replace(result.scheme) + result_scheme.model = result_scheme.model.markdown() result = dataclasses.replace(result) result.scheme = scheme_path diff --git a/glotaran/builtin/models/kinetic_image/test/__init__.py b/glotaran/builtin/megacomplexes/__init__.py similarity index 100% rename from glotaran/builtin/models/kinetic_image/test/__init__.py rename to glotaran/builtin/megacomplexes/__init__.py diff --git a/glotaran/builtin/megacomplexes/baseline/__init__.py b/glotaran/builtin/megacomplexes/baseline/__init__.py new file mode 100644 index 000000000..54aab0229 --- /dev/null +++ 
b/glotaran/builtin/megacomplexes/baseline/__init__.py @@ -0,0 +1 @@ +from glotaran.builtin.megacomplexes.baseline.baseline_megacomplex import BaselineMegacomplex diff --git a/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py b/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py new file mode 100644 index 000000000..942384ed9 --- /dev/null +++ b/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import numpy as np +import xarray as xr + +from glotaran.model import DatasetModel +from glotaran.model import Megacomplex +from glotaran.model import megacomplex + + +@megacomplex(unique=True, register_as="baseline") +class BaselineMegacomplex(Megacomplex): + def calculate_matrix( + self, + dataset_model: DatasetModel, + indices: dict[str, int], + **kwargs, + ): + model_dimension = dataset_model.get_model_dimension() + model_axis = dataset_model.get_coordinates()[model_dimension] + clp_label = [f"{dataset_model.label}_baseline"] + matrix = np.ones((model_axis.size, 1), dtype=np.float64) + return xr.DataArray( + matrix, coords=((model_dimension, model_axis.data), ("clp_label", clp_label)) + ) + + def index_dependent(self, dataset: DatasetModel) -> bool: + return False + + def finalize_data(self, dataset_model: DatasetModel, data: xr.Dataset): + data[f"{dataset_model.label}_baseline"] = data.clp.sel(clp_label="baseline") diff --git a/glotaran/builtin/models/kinetic_image/test/test_baseline.py b/glotaran/builtin/megacomplexes/baseline/test/test_baseline_megacomplex.py similarity index 75% rename from glotaran/builtin/models/kinetic_image/test/test_baseline.py rename to glotaran/builtin/megacomplexes/baseline/test/test_baseline_megacomplex.py index c855d1103..ccff81bc1 100644 --- a/glotaran/builtin/models/kinetic_image/test/test_baseline.py +++ b/glotaran/builtin/megacomplexes/baseline/test/test_baseline_megacomplex.py @@ -2,19 +2,21 @@ import xarray as xr from glotaran.analysis.util import calculate_matrix -from glotaran.builtin.models.kinetic_image import KineticImageModel +from glotaran.builtin.megacomplexes.baseline import BaselineMegacomplex +from glotaran.builtin.megacomplexes.decay import DecayMegacomplex +from glotaran.model import Model from glotaran.parameter import ParameterGroup def test_baseline(): - model = KineticImageModel.from_dict( + model = Model.from_dict( { "initial_concentration": { "j1": {"compartments": ["s1"], "parameters": ["2"]}, }, "megacomplex": { - "mc1": {"type": "kinetic-decay", "k_matrix": ["k1"]}, - "mc2": {"type": "kinetic-baseline"}, + "mc1": {"type": "decay", "k_matrix": ["k1"]}, + "mc2": {"type": "baseline", "dimension": "time"}, }, "k_matrix": { "k1": { @@ -29,7 +31,8 @@ def test_baseline(): "megacomplex": ["mc1", "mc2"], }, }, - } + }, + megacomplex_types={"decay": DecayMegacomplex, "baseline": BaselineMegacomplex}, ) parameter = ParameterGroup.from_list( @@ -45,7 +48,7 @@ def test_baseline(): coords = {"time": time, "pixel": pixel} dataset_model = model.dataset["dataset1"].fill(model, parameter) dataset_model.overwrite_global_dimension("pixel") - dataset_model.set_coords(coords) + dataset_model.set_coordinates(coords) matrix = calculate_matrix(dataset_model, {}) compartments = matrix.coords["clp_label"] diff --git a/glotaran/builtin/megacomplexes/coherent_artifact/__init__.py b/glotaran/builtin/megacomplexes/coherent_artifact/__init__.py new file mode 100644 index 000000000..6810c1d97 --- /dev/null +++ b/glotaran/builtin/megacomplexes/coherent_artifact/__init__.py @@ 
-0,0 +1,3 @@ +from glotaran.builtin.megacomplexes.coherent_artifact.coherent_artifact_megacomplex import ( + CoherentArtifactMegacomplex, +) diff --git a/glotaran/builtin/models/kinetic_spectrum/coherent_artifact_megacomplex.py b/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py similarity index 64% rename from glotaran/builtin/models/kinetic_spectrum/coherent_artifact_megacomplex.py rename to glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py index 5d4e86d61..eb22c3f64 100644 --- a/glotaran/builtin/models/kinetic_spectrum/coherent_artifact_megacomplex.py +++ b/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py @@ -5,8 +5,9 @@ import numpy as np import xarray as xr -from glotaran.builtin.models.kinetic_image.irf import IrfMultiGaussian -from glotaran.model import DatasetDescriptor +from glotaran.builtin.megacomplexes.decay.irf import Irf +from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian +from glotaran.model import DatasetModel from glotaran.model import Megacomplex from glotaran.model import ModelError from glotaran.model import megacomplex @@ -14,16 +15,21 @@ @megacomplex( - "time", + dimension="time", + unique=True, properties={ "order": {"type": int}, "width": {"type": Parameter, "allow_none": True}, }, + dataset_model_items={ + "irf": {"type": Irf, "allow_none": True}, + }, + register_as="coherent-artifact", ) class CoherentArtifactMegacomplex(Megacomplex): def calculate_matrix( self, - dataset_model: DatasetDescriptor, + dataset_model: DatasetModel, indices: dict[str, int], **kwargs, ): @@ -38,9 +44,9 @@ def calculate_matrix( global_dimension = dataset_model.get_global_dimension() global_index = indices.get(global_dimension) - global_axis = dataset_model.get_coords().get(global_dimension).values + global_axis = dataset_model.get_coordinates().get(global_dimension).values model_dimension = dataset_model.get_model_dimension() - model_axis = dataset_model.get_coords()[model_dimension].values + model_axis = dataset_model.get_coordinates()[model_dimension].values irf = dataset_model.irf @@ -56,9 +62,22 @@ def calculate_matrix( def compartments(self): return [f"coherent_artifact_{i}" for i in range(1, self.order + 1)] - def index_dependent(self, dataset: DatasetDescriptor) -> bool: + def index_dependent(self, dataset: DatasetModel) -> bool: return False + def finalize_data(self, dataset_model: DatasetModel, data: xr.Dataset): + global_dimension = dataset_model.get_global_dimension() + model_dimension = dataset_model.get_model_dimension() + data.coords["coherent_artifact_order"] = list(range(1, self.order + 1)) + data["coherent_artifact_concentration"] = ( + (model_dimension, "coherent_artifact_order"), + data.matrix.sel(clp_label=self.compartments()).values, + ) + data["coherent_artifact_associated_spectra"] = ( + (global_dimension, "coherent_artifact_order"), + data.clp.sel(clp_label=self.compartments()).values, + ) + @nb.jit(nopython=True, parallel=True) def _calculate_coherent_artifact_matrix(center, width, axis, order): diff --git a/glotaran/builtin/models/kinetic_spectrum/test/test_coherent_artifact.py b/glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py similarity index 85% rename from glotaran/builtin/models/kinetic_spectrum/test/test_coherent_artifact.py rename to glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py index 893f019b0..f09697b98 100644 --- 
a/glotaran/builtin/models/kinetic_spectrum/test/test_coherent_artifact.py +++ b/glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py @@ -4,7 +4,9 @@ from glotaran.analysis.optimize import optimize from glotaran.analysis.simulation import simulate from glotaran.analysis.util import calculate_matrix -from glotaran.builtin.models.kinetic_spectrum import KineticSpectrumModel +from glotaran.builtin.megacomplexes.coherent_artifact import CoherentArtifactMegacomplex +from glotaran.builtin.megacomplexes.decay import DecayMegacomplex +from glotaran.model import Model from glotaran.parameter import ParameterGroup from glotaran.project import Scheme @@ -15,7 +17,7 @@ def test_coherent_artifact(): "j1": {"compartments": ["s1"], "parameters": ["2"]}, }, "megacomplex": { - "mc1": {"type": "kinetic-decay", "k_matrix": ["k1"]}, + "mc1": {"type": "decay", "k_matrix": ["k1"]}, "mc2": {"type": "coherent-artifact", "order": 3}, }, "k_matrix": { @@ -40,7 +42,13 @@ def test_coherent_artifact(): }, }, } - model = KineticSpectrumModel.from_dict(model_dict.copy()) + model = Model.from_dict( + model_dict.copy(), + megacomplex_types={ + "decay": DecayMegacomplex, + "coherent-artifact": CoherentArtifactMegacomplex, + }, + ) parameters = ParameterGroup.from_list( [ @@ -57,7 +65,7 @@ def test_coherent_artifact(): dataset_model = model.dataset["dataset1"].fill(model, parameters) dataset_model.overwrite_global_dimension("spectral") - dataset_model.set_coords(coords) + dataset_model.set_coordinates(coords) matrix = calculate_matrix(dataset_model, {}) compartments = matrix.coords["clp_label"].values @@ -83,7 +91,7 @@ def test_coherent_artifact(): ), ], ) - axis = {"time": time, "spectral": clp.spectral.data} + axis = {"time": time, "spectral": clp.spectral} data = simulate(model, "dataset1", parameters, axis, clp) dataset = {"dataset1": data} diff --git a/glotaran/builtin/megacomplexes/decay/__init__.py b/glotaran/builtin/megacomplexes/decay/__init__.py new file mode 100644 index 000000000..66203628a --- /dev/null +++ b/glotaran/builtin/megacomplexes/decay/__init__.py @@ -0,0 +1 @@ +from glotaran.builtin.megacomplexes.decay.decay_megacomplex import DecayMegacomplex diff --git a/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py b/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py new file mode 100644 index 000000000..0ea4a0ad1 --- /dev/null +++ b/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py @@ -0,0 +1,131 @@ +"""This package contains the decay megacomplex item.""" +from __future__ import annotations + +from typing import List + +import numpy as np +import xarray as xr + +from glotaran.builtin.megacomplexes.decay.initial_concentration import InitialConcentration +from glotaran.builtin.megacomplexes.decay.irf import Irf +from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian +from glotaran.builtin.megacomplexes.decay.irf import IrfSpectralMultiGaussian +from glotaran.builtin.megacomplexes.decay.k_matrix import KMatrix +from glotaran.builtin.megacomplexes.decay.util import decay_matrix_implementation +from glotaran.builtin.megacomplexes.decay.util import retrieve_decay_associated_data +from glotaran.builtin.megacomplexes.decay.util import retrieve_irf +from glotaran.builtin.megacomplexes.decay.util import retrieve_species_associated_data +from glotaran.model import DatasetModel +from glotaran.model import Megacomplex +from glotaran.model import ModelError +from glotaran.model import megacomplex + + +@megacomplex( + dimension="time", + model_items={ + 
"k_matrix": List[KMatrix], + }, + properties={}, + dataset_model_items={ + "initial_concentration": {"type": InitialConcentration, "allow_none": True}, + "irf": {"type": Irf, "allow_none": True}, + }, + register_as="decay", +) +class DecayMegacomplex(Megacomplex): + """A Megacomplex with one or more K-Matrices.""" + + def has_k_matrix(self) -> bool: + return len(self.k_matrix) != 0 + + def full_k_matrix(self, model=None): + full_k_matrix = None + for k_matrix in self.k_matrix: + if model: + k_matrix = model.k_matrix[k_matrix] + if full_k_matrix is None: + full_k_matrix = k_matrix + # If multiple k matrices are present, we combine them + else: + full_k_matrix = full_k_matrix.combine(k_matrix) + return full_k_matrix + + @property + def involved_compartments(self): + return self.full_k_matrix().involved_compartments() if self.full_k_matrix() else [] + + def index_dependent(self, dataset_model: DatasetModel) -> bool: + return ( + isinstance(dataset_model.irf, IrfSpectralMultiGaussian) + and dataset_model.irf.dispersion_center is not None + ) or ( + isinstance(dataset_model.irf, IrfMultiGaussian) and dataset_model.irf.shift is not None + ) + + def calculate_matrix( + self, + dataset_model: DatasetModel, + indices: dict[str, int], + **kwargs, + ): + if dataset_model.initial_concentration is None: + raise ModelError( + f'No initial concentration specified in dataset "{dataset_model.label}"' + ) + initial_concentration = dataset_model.initial_concentration.normalized() + + k_matrix = self.full_k_matrix() + + # we might have more species in the model then in the k matrix + species = [ + comp + for comp in initial_concentration.compartments + if comp in k_matrix.involved_compartments() + ] + + # the rates are the eigenvalues of the k matrix + rates = k_matrix.rates(initial_concentration) + + global_dimension = dataset_model.get_global_dimension() + global_index = indices.get(global_dimension) + global_axis = dataset_model.get_coordinates().get(global_dimension).values + model_dimension = dataset_model.get_model_dimension() + model_axis = dataset_model.get_coordinates()[model_dimension].values + + # init the matrix + size = (model_axis.size, rates.size) + matrix = np.zeros(size, dtype=np.float64) + + decay_matrix_implementation( + matrix, rates, global_index, global_axis, model_axis, dataset_model + ) + + if not np.all(np.isfinite(matrix)): + raise ValueError( + f"Non-finite concentrations for K-Matrix '{k_matrix.label}':\n" + f"{k_matrix.matrix_as_markdown(fill_parameters=True)}" + ) + + # apply A matrix + matrix = matrix @ k_matrix.a_matrix(initial_concentration) + + # done + return xr.DataArray(matrix, coords=((model_dimension, model_axis), ("clp_label", species))) + + def finalize_data(self, dataset_model: DatasetModel, data: xr.Dataset): + global_dimension = dataset_model.get_global_dimension() + name = "images" if global_dimension == "pixel" else "spectra" + + if "species" not in data.coords: + # We are the first Decay complex called and add SAD for all decay megacomplexes + retrieve_species_associated_data(dataset_model, data, global_dimension, name) + if isinstance(dataset_model.irf, IrfMultiGaussian) and "irf" not in data: + retrieve_irf(dataset_model, data, global_dimension) + + multiple_complexes = ( + len([m for m in dataset_model.megacomplex if isinstance(m, DecayMegacomplex)]) > 1 + ) + retrieve_decay_associated_data( + self, dataset_model, data, global_dimension, name, multiple_complexes + ) diff --git a/glotaran/builtin/models/kinetic_image/initial_concentration.py 
b/glotaran/builtin/megacomplexes/decay/initial_concentration.py similarity index 74% rename from glotaran/builtin/models/kinetic_image/initial_concentration.py rename to glotaran/builtin/megacomplexes/decay/initial_concentration.py index 1d6d952f4..43f18006e 100644 --- a/glotaran/builtin/models/kinetic_image/initial_concentration.py +++ b/glotaran/builtin/megacomplexes/decay/initial_concentration.py @@ -2,19 +2,19 @@ from __future__ import annotations import copy -import typing +from typing import List import numpy as np -from glotaran.model import model_attribute +from glotaran.model import model_item from glotaran.parameter import Parameter -@model_attribute( +@model_item( properties={ - "compartments": typing.List[str], - "parameters": typing.List[Parameter], - "exclude_from_normalize": {"type": typing.List[str], "default": []}, + "compartments": List[str], + "parameters": List[Parameter], + "exclude_from_normalize": {"type": List[str], "default": []}, } ) class InitialConcentration: diff --git a/glotaran/builtin/models/kinetic_image/irf.py b/glotaran/builtin/megacomplexes/decay/irf.py similarity index 53% rename from glotaran/builtin/models/kinetic_image/irf.py rename to glotaran/builtin/megacomplexes/decay/irf.py index febf38758..e1297c74f 100644 --- a/glotaran/builtin/models/kinetic_image/irf.py +++ b/glotaran/builtin/megacomplexes/decay/irf.py @@ -6,17 +6,17 @@ import numpy as np from glotaran.model import ModelError -from glotaran.model import model_attribute -from glotaran.model import model_attribute_typed +from glotaran.model import model_item +from glotaran.model import model_item_typed from glotaran.parameter import Parameter -@model_attribute(has_type=True) +@model_item(has_type=True) class IrfMeasured: """A measured IRF. The data must be supplied by the dataset.""" -@model_attribute( +@model_item( properties={ "center": List[Parameter], "width": List[Parameter], @@ -59,6 +59,7 @@ class IrfMultiGaussian: def parameter( self, global_index: int, global_axis: np.ndarray ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, bool, float]: + """Returns the properties of the irf with shift applied.""" centers = self.center if isinstance(self.center, list) else [self.center] centers = np.asarray([c.value for c in centers]) @@ -76,10 +77,8 @@ def parameter( ) if len_centers == 1: centers = [centers[0] for _ in range(len_widths)] - len_centers = len_widths else: widths = [widths[0] for _ in range(len_centers)] - len_widths = len_centers scales = self.scale if self.scale is not None else [1.0 for _ in centers] scales = scales if isinstance(scales, list) else [scales] @@ -108,7 +107,7 @@ def calculate(self, index: int, global_axis: np.ndarray, model_axis: np.ndarray) ) -@model_attribute( +@model_item( properties={ "center": Parameter, "width": Parameter, @@ -119,11 +118,97 @@ class IrfGaussian(IrfMultiGaussian): pass -@model_attribute_typed( +@model_item( + properties={ + "dispersion_center": {"type": Parameter, "allow_none": True}, + "center_dispersion": {"type": List[Parameter], "default": []}, + "width_dispersion": {"type": List[Parameter], "default": []}, + "model_dispersion_with_wavenumber": {"type": bool, "default": False}, + }, + has_type=True, +) +class IrfSpectralMultiGaussian(IrfMultiGaussian): + """ + Represents a gaussian IRF. + + One width and one center is a single gauss. + + One center and multiple widths is a multiple gaussian. + + Multiple center and multiple widths is Double-, Triple- , etc. Gaussian. 
+ + Parameters + ---------- + + label: + label of the irf + center: + one or more center of the irf as parameter indices + width: + one or more widths of the gaussian as parameter index + center_dispersion: + polynomial coefficients for the dispersion of the + center as list of parameter indices. None for no dispersion. + width_dispersion: + polynomial coefficients for the dispersion of the + width as parameter indices. None for no dispersion. + + """ + + def parameter(self, global_index: int, global_axis: np.ndarray): + """Returns the properties of the irf with shift and dispersion applied.""" + centers, widths, scale, shift, backsweep, backsweep_period = super().parameter( + global_index, global_axis + ) + + index = global_axis[global_index] if global_index is not None else None + + if self.dispersion_center is not None: + dist = ( + (1e3 / index - 1e3 / self.dispersion_center) + if self.model_dispersion_with_wavenumber + else (index - self.dispersion_center) / 100 + ) + + if len(self.center_dispersion) != 0: + if self.dispersion_center is None: + raise ModelError(f"No dispersion center defined for irf '{self.label}'") + for i, disp in enumerate(self.center_dispersion): + centers += disp * np.power(dist, i + 1) + + if len(self.width_dispersion) != 0: + if self.dispersion_center is None: + raise ModelError(f"No dispersion center defined for irf '{self.label}'") + for i, disp in enumerate(self.width_dispersion): + widths = widths + disp * np.power(dist, i + 1) + + return centers, widths, scale, shift, backsweep, backsweep_period + + def calculate_dispersion(self, axis): + dispersion = [] + for index, _ in enumerate(axis): + center, _, _, _, _, _ = self.parameter(index, axis) + dispersion.append(center) + return np.asarray(dispersion).T + + +@model_item( + properties={ + "center": Parameter, + "width": Parameter, + }, + has_type=True, +) +class IrfSpectralGaussian(IrfSpectralMultiGaussian): + pass + + +@model_item_typed( types={ "gaussian": IrfGaussian, "multi-gaussian": IrfMultiGaussian, - "measured": IrfMeasured, + "spectral-multi-gaussian": IrfSpectralMultiGaussian, + "spectral-gaussian": IrfSpectralGaussian, } ) class Irf: diff --git a/glotaran/builtin/models/kinetic_image/k_matrix.py b/glotaran/builtin/megacomplexes/decay/k_matrix.py similarity index 96% rename from glotaran/builtin/models/kinetic_image/k_matrix.py rename to glotaran/builtin/megacomplexes/decay/k_matrix.py index e30280ce0..71a4cda27 100644 --- a/glotaran/builtin/models/kinetic_image/k_matrix.py +++ b/glotaran/builtin/megacomplexes/decay/k_matrix.py @@ -6,15 +6,16 @@ from collections import OrderedDict import numpy as np -import scipy +from scipy.linalg import eig +from scipy.linalg import solve -from glotaran.builtin.models.kinetic_image.initial_concentration import InitialConcentration -from glotaran.model import model_attribute +from glotaran.builtin.megacomplexes.decay.initial_concentration import InitialConcentration +from glotaran.model import model_item from glotaran.parameter import Parameter from glotaran.utils.ipython import MarkdownStr -@model_attribute( +@model_item( properties={ "matrix": {"type": typing.Dict[typing.Tuple[str, str], Parameter]}, }, @@ -202,7 +203,7 @@ def eigen(self, compartments: list[str]) -> tuple[np.ndarray, np.ndarray]: matrix = self.full(compartments).T # get the eigenvectors and values, we take the left ones to have # computation consistent with TIMP - eigenvalues, eigenvectors = scipy.linalg.eig(matrix, left=True, right=False) + eigenvalues, eigenvectors = eig(matrix, left=True, 
right=False) return (eigenvalues.real, eigenvectors.real) def rates(self, initial_concentration: InitialConcentration) -> np.ndarray: @@ -231,7 +232,7 @@ def _gamma( for c in compartments ] - gamma = scipy.linalg.solve(eigenvectors, initial_concentration) + gamma = solve(eigenvectors, initial_concentration) return np.diag(gamma) def a_matrix(self, initial_concentration: InitialConcentration) -> np.ndarray: diff --git a/glotaran/builtin/models/kinetic_image/test/test_kinetic_image_model.py b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py similarity index 78% rename from glotaran/builtin/models/kinetic_image/test/test_kinetic_image_model.py rename to glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py index 4e0ecc00c..c89514984 100644 --- a/glotaran/builtin/models/kinetic_image/test/test_kinetic_image_model.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py @@ -4,7 +4,8 @@ from glotaran.analysis.optimize import optimize from glotaran.analysis.simulation import simulate -from glotaran.builtin.models.kinetic_image import KineticImageModel +from glotaran.builtin.megacomplexes.decay import DecayMegacomplex +from glotaran.model import Model from glotaran.parameter import ParameterGroup from glotaran.project import Scheme @@ -15,12 +16,23 @@ def _create_gaussian_clp(labels, amplitudes, centers, widths, axis): amplitudes[i] * np.exp(-np.log(2) * np.square(2 * (axis - centers[i]) / widths[i])) for i, _ in enumerate(labels) ], - coords=[("clp_label", labels), ("pixel", axis)], + coords=[("clp_label", labels), ("pixel", axis.data)], ).T +class DecayModel(Model): + @classmethod + def from_dict(cls, model_dict): + return super().from_dict( + model_dict, + megacomplex_types={ + "decay": DecayMegacomplex, + }, + ) + + class OneComponentOneChannel: - model = KineticImageModel.from_dict( + model = DecayModel.from_dict( { "initial_concentration": { "j1": {"compartments": ["s1"], "parameters": ["2"]}, @@ -51,14 +63,15 @@ class OneComponentOneChannel: [101e-3, [1, {"vary": False, "non-negative": False}]] ) - time = np.asarray(np.arange(0, 50, 1.5)) - axis = {"time": time, "pixel": np.asarray([0])} + time = xr.DataArray(np.arange(0, 50, 1.5)) + pixel = xr.DataArray([0]) + axis = {"time": time, "pixel": pixel} - clp = xr.DataArray([[1]], coords=[("pixel", [0]), ("clp_label", ["s1"])]) + clp = xr.DataArray([[1]], coords=[("pixel", pixel.data), ("clp_label", ["s1"])]) class OneComponentOneChannelGaussianIrf: - model = KineticImageModel.from_dict( + model = DecayModel.from_dict( { "initial_concentration": { "j1": {"compartments": ["s1"], "parameters": ["5"]}, @@ -85,11 +98,13 @@ class OneComponentOneChannelGaussianIrf: }, } ) - assert model.overwrite_index_dependent() initial_parameters = ParameterGroup.from_list( [101e-4, 0.1, 1, [0.1, {"vary": False}], [1, {"vary": False, "non-negative": False}]] ) + assert model.megacomplex["mc1"].index_dependent( + model.dataset["dataset1"].fill(model, initial_parameters) + ) wanted_parameters = ParameterGroup.from_list( [ [101e-3, {"non-negative": True}], @@ -100,62 +115,15 @@ class OneComponentOneChannelGaussianIrf: ] ) - time = np.asarray(np.arange(-10, 50, 1.5)) - axis = {"time": time, "pixel": np.asarray([0])} - clp = xr.DataArray([[1]], coords=[("pixel", [0]), ("clp_label", ["s1"])]) - - -class OneComponentOneChannelMeasuredIrf: - model = KineticImageModel.from_dict( - { - "initial_concentration": { - "j1": {"compartments": ["s1"], "parameters": ["2"]}, - }, - "megacomplex": { - "mc1": {"k_matrix": ["k1"]}, - 
}, - "k_matrix": { - "k1": { - "matrix": { - ("s1", "s1"): "1", - } - } - }, - "irf": { - "irf1": {"type": "measured"}, - }, - "dataset": { - "dataset1": { - "initial_concentration": "j1", - "irf": "irf1", - "megacomplex": ["mc1"], - }, - }, - } - ) - - initial_parameters = ParameterGroup.from_list( - [101e-4, [1, {"vary": False, "non-negative": False}]] - ) - wanted_parameters = ParameterGroup.from_list( - [101e-3, [1, {"vary": False, "non-negative": False}]] - ) - - time = np.asarray(np.arange(-10, 50, 1.5)) - axis = {"time": time, "pixel": np.asarray([0])} - - center = 0 - width = 5 - irf = (1 / np.sqrt(2 * np.pi)) * np.exp( - -(time - center) * (time - center) / (2 * width * width) - ) - model.irf["irf1"].irfdata = irf + time = xr.DataArray(np.arange(0, 50, 1.5)) + pixel = xr.DataArray([0]) + axis = {"time": time, "pixel": pixel} - clp = xr.DataArray([[1]], coords=[("pixel", [0]), ("clp_label", ["s1"])]) + clp = xr.DataArray([[1]], coords=[("pixel", pixel.data), ("clp_label", ["s1"])]) class ThreeComponentParallel: - model = KineticImageModel.from_dict( + model = DecayModel.from_dict( { "initial_concentration": { "j1": {"compartments": ["s1", "s2", "s3"], "parameters": ["j.1", "j.1", "j.1"]}, @@ -211,9 +179,9 @@ class ThreeComponentParallel: "j": [["1", 1, {"vary": False, "non-negative": False}]], } ) + time = xr.DataArray(np.arange(-10, 100, 1.5)) + pixel = xr.DataArray(np.arange(600, 750, 10)) - time = np.arange(-10, 100, 1.5) - pixel = np.arange(600, 750, 10) axis = {"time": time, "pixel": pixel} clp = _create_gaussian_clp( @@ -222,7 +190,7 @@ class ThreeComponentParallel: class ThreeComponentSequential: - model = KineticImageModel.from_dict( + model = DecayModel.from_dict( { "initial_concentration": { "j1": {"compartments": ["s1", "s2", "s3"], "parameters": ["j.1", "j.0", "j.0"]}, @@ -286,8 +254,8 @@ class ThreeComponentSequential: } ) - time = np.asarray(np.arange(-10, 50, 1.0)) - pixel = np.arange(600, 750, 10) + time = xr.DataArray(np.arange(-10, 50, 1.0)) + pixel = xr.DataArray(np.arange(600, 750, 10)) axis = {"time": time, "pixel": pixel} clp = _create_gaussian_clp( @@ -300,7 +268,6 @@ class ThreeComponentSequential: [ OneComponentOneChannel, OneComponentOneChannelGaussianIrf, - # OneComponentOneChannelMeasuredIrf, ThreeComponentParallel, ThreeComponentSequential, ], diff --git a/glotaran/builtin/models/kinetic_image/test/test_k_matrix.py b/glotaran/builtin/megacomplexes/decay/test/test_k_matrix.py similarity index 93% rename from glotaran/builtin/models/kinetic_image/test/test_k_matrix.py rename to glotaran/builtin/megacomplexes/decay/test/test_k_matrix.py index f67da027c..08a85cf78 100644 --- a/glotaran/builtin/models/kinetic_image/test/test_k_matrix.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_k_matrix.py @@ -2,8 +2,8 @@ import pytest from IPython.core.formatters import format_display_data -from glotaran.builtin.models.kinetic_image.initial_concentration import InitialConcentration -from glotaran.builtin.models.kinetic_image.k_matrix import KMatrix +from glotaran.builtin.megacomplexes.decay.initial_concentration import InitialConcentration +from glotaran.builtin.megacomplexes.decay.k_matrix import KMatrix from glotaran.parameter import ParameterGroup @@ -316,10 +316,11 @@ def test_kmatrix_ipython_rendering(): rendered_obj = format_display_data(kmatrix)[0] - assert "text/markdown" in rendered_obj - assert rendered_obj["text/markdown"].startswith("| compartment") + test_markdown_str = "text/markdown" + assert test_markdown_str in rendered_obj + assert 
rendered_obj[test_markdown_str].startswith("| compartment") rendered_markdown_return = format_display_data(kmatrix.matrix_as_markdown())[0] - assert "text/markdown" in rendered_markdown_return - assert rendered_markdown_return["text/markdown"].startswith("| compartment") + assert test_markdown_str in rendered_markdown_return + assert rendered_markdown_return[test_markdown_str].startswith("| compartment") diff --git a/glotaran/builtin/models/kinetic_spectrum/test/test_spectral_irf.py b/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py similarity index 88% rename from glotaran/builtin/models/kinetic_spectrum/test/test_spectral_irf.py rename to glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py index f04ba0614..e2ca08519 100644 --- a/glotaran/builtin/models/kinetic_spectrum/test/test_spectral_irf.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py @@ -1,5 +1,8 @@ +from copy import deepcopy + import numpy as np import pytest +import xarray as xr from glotaran.analysis.optimize import optimize from glotaran.analysis.simulation import simulate @@ -8,14 +11,12 @@ from glotaran.project import Scheme MODEL_BASE = """\ -type: kinetic-spectrum +default-megacomplex: decay dataset: dataset1: megacomplex: [mc1] initial_concentration: j1 irf: irf1 - shape: - s1: sh1 initial_concentration: j1: compartments: [s1] @@ -23,6 +24,10 @@ megacomplex: mc1: k_matrix: [k1] + mc2: + type: spectral + shape: + s1: sh1 k_matrix: k1: matrix: @@ -87,16 +92,16 @@ class SimpleIrfDispersion: time_p1 = np.linspace(-1, 2, 50, endpoint=False) time_p2 = np.linspace(2, 5, 30, endpoint=False) time_p3 = np.geomspace(5, 10, num=20) - time = np.concatenate([time_p1, time_p2, time_p3]) - spectral = np.arange(300, 500, 100) + time = xr.DataArray(np.concatenate([time_p1, time_p2, time_p3])) + spectral = xr.DataArray(np.arange(300, 500, 100)) axis = {"time": time, "spectral": spectral} class MultiIrfDispersion: model = load_model(MODEL_MULTI_IRF_DISPERSION, format_name="yml_str") parameters = load_parameters(PARAMETERS_MULTI_IRF_DISPERSION, format_name="yml_str") - time = np.arange(-1, 5, 0.2) - spectral = np.arange(300, 500, 100) + time = xr.DataArray(np.arange(-1, 5, 0.2)) + spectral = xr.DataArray(np.arange(300, 500, 100)) axis = {"time": time, "spectral": spectral} @@ -117,7 +122,9 @@ def test_spectral_irf(suite): print(model.validate(parameters)) assert model.valid(parameters) - dataset = simulate(model, "dataset1", parameters, suite.axis) + sim_model = deepcopy(model) + sim_model.dataset["dataset1"].global_megacomplex = ["mc2"] + dataset = simulate(sim_model, "dataset1", parameters, suite.axis) assert dataset.data.shape == (suite.axis["time"].size, suite.axis["spectral"].size) @@ -137,7 +144,7 @@ def test_spectral_irf(suite): resultdata = result.data["dataset1"] - # print(resultdata) + print(resultdata) assert np.array_equal(dataset["time"], resultdata["time"]) assert np.array_equal(dataset["spectral"], resultdata["spectral"]) @@ -153,3 +160,4 @@ def test_spectral_irf(suite): assert "species_associated_spectra" in resultdata assert "decay_associated_spectra" in resultdata + assert "irf_center" in resultdata diff --git a/glotaran/builtin/megacomplexes/decay/util.py b/glotaran/builtin/megacomplexes/decay/util.py new file mode 100644 index 000000000..7c22c768a --- /dev/null +++ b/glotaran/builtin/megacomplexes/decay/util.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numba as nb +import numpy as np +import xarray as xr + +from 
glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian +from glotaran.builtin.megacomplexes.decay.irf import IrfSpectralMultiGaussian +from glotaran.model import DatasetModel + +if TYPE_CHECKING: + from glotaran.builtin.megacomplexes.decay.decay_megacomplex import DecayMegacomplex + + +def decay_matrix_implementation( + matrix: np.ndarray, + rates: np.ndarray, + global_index: int, + global_axis: np.ndarray, + model_axis: np.ndarray, + dataset_model: DatasetModel, +): + if isinstance(dataset_model.irf, IrfMultiGaussian): + + ( + centers, + widths, + irf_scales, + shift, + backsweep, + backsweep_period, + ) = dataset_model.irf.parameter(global_index, global_axis) + + for center, width, irf_scale in zip(centers, widths, irf_scales): + calculate_decay_matrix_gaussian_irf( + matrix, + rates, + model_axis, + center - shift, + width, + irf_scale, + backsweep, + backsweep_period, + ) + if dataset_model.irf.normalize: + matrix /= np.sum(irf_scale) + + else: + calculate_decay_matrix_no_irf(matrix, rates, model_axis) + + +@nb.jit(nopython=True, parallel=True) +def calculate_decay_matrix_no_irf(matrix, rates, times): + for n_r in nb.prange(rates.size): + r_n = rates[n_r] + for n_t in range(times.size): + t_n = times[n_t] + matrix[n_t, n_r] += np.exp(r_n * t_n) + + +sqrt2 = np.sqrt(2) + + +@nb.jit(nopython=True, parallel=True) +def calculate_decay_matrix_gaussian_irf( + matrix, rates, times, center, width, scale, backsweep, backsweep_period +): + """Calculates a decay matrix with a gaussian irf.""" + for n_r in nb.prange(rates.size): + r_n = -rates[n_r] + backsweep_valid = abs(r_n) * backsweep_period > 0.001 + alpha = (r_n * width) / sqrt2 + for n_t in nb.prange(times.size): + t_n = times[n_t] + beta = (t_n - center) / (width * sqrt2) + thresh = beta - alpha + if thresh < -1: + matrix[n_t, n_r] += scale * 0.5 * erfcx(-thresh) * np.exp(-beta * beta) + else: + matrix[n_t, n_r] += ( + scale * 0.5 * (1 + erf(thresh)) * np.exp(alpha * (alpha - 2 * beta)) + ) + if backsweep and backsweep_valid: + x1 = np.exp(-r_n * (t_n - center + backsweep_period)) + x2 = np.exp(-r_n * ((backsweep_period / 2) - (t_n - center))) + x3 = np.exp(-r_n * backsweep_period) + matrix[n_t, n_r] += scale * (x1 + x2) / (1 - x3) + + +import ctypes # noqa: E402 + +# This is a work around to use scipy.special function with numba +from numba.extending import get_cython_function_address # noqa: E402 + +_dble = ctypes.c_double + +functype = ctypes.CFUNCTYPE(_dble, _dble) + +erf_addr = get_cython_function_address("scipy.special.cython_special", "__pyx_fuse_1erf") +erfcx_addr = get_cython_function_address("scipy.special.cython_special", "__pyx_fuse_1erfcx") + +erf = functype(erf_addr) +erfcx = functype(erfcx_addr) + + +def retrieve_species_associated_data( + dataset_model: DatasetModel, data: xr.Dataset, global_dimension: str, name: str +): + species = dataset_model.initial_concentration.compartments + model_dimension = dataset_model.get_model_dimension() + + data.coords["species"] = species + data[f"species_associated_{name}"] = ( + ( + global_dimension, + "species", + ), + data.clp.sel(clp_label=species).data, + ) + + if len(data.matrix.shape) == 3: + # index dependent + data["species_concentration"] = ( + ( + global_dimension, + model_dimension, + "species", + ), + data.matrix.sel(clp_label=species).values, + ) + else: + # index independent + data["species_concentration"] = ( + ( + model_dimension, + "species", + ), + data.matrix.sel(clp_label=species).values, + ) + + +def retrieve_decay_associated_data( + megacomplex: 
DecayMegacomplex, + dataset_model: DatasetModel, + data: xr.Dataset, + global_dimension: str, + name: str, + multiple_complexes: bool, +): + k_matrix = megacomplex.full_k_matrix() + + species = dataset_model.initial_concentration.compartments + species = [c for c in species if c in k_matrix.involved_compartments()] + + matrix = k_matrix.full(species) + matrix_reduced = k_matrix.reduced(species) + a_matrix = k_matrix.a_matrix(dataset_model.initial_concentration) + rates = k_matrix.rates(dataset_model.initial_concentration) + lifetimes = 1 / rates + + das = data[f"species_associated_{name}"].sel(species=species).values @ a_matrix.T + + component_coords = {"rate": ("component", rates), "lifetime": ("component", lifetimes)} + das_coords = component_coords.copy() + das_coords[global_dimension] = data.coords[global_dimension] + das_name = f"decay_associated_{name}" + das = xr.DataArray(das, dims=(global_dimension, "component"), coords=das_coords) + + a_matrix_coords = component_coords.copy() + a_matrix_coords["species"] = species + a_matrix_name = "a_matrix" + a_matrix = xr.DataArray(a_matrix, coords=a_matrix_coords, dims=("component", "species")) + + k_matrix_name = "k_matrix" + k_matrix = xr.DataArray(matrix, coords=[("to_species", species), ("from_species", species)]) + + k_matrix_reduced_name = "k_matrix_reduced" + k_matrix_reduced = xr.DataArray( + matrix_reduced, coords=[("to_species", species), ("from_species", species)] + ) + + if multiple_complexes: + das_name = f"decay_associated_{name}_{megacomplex.label}" + das = das.rename(component=f"component_{megacomplex.label}") + a_matrix_name = f"a_matrix_{megacomplex.label}" + a_matrix = a_matrix.rename(component=f"component_{megacomplex.label}") + k_matrix_name = f"k_matrix_{megacomplex.label}" + k_matrix_reduced_name = f"k_matrix_reduced_{megacomplex.label}" + + data[das_name] = das + data[a_matrix_name] = a_matrix + data[k_matrix_name] = k_matrix + data[k_matrix_reduced_name] = k_matrix_reduced + + +def retrieve_irf(dataset_model: DatasetModel, data: xr.Dataset, global_dimension: str): + + irf = dataset_model.irf + model_dimension = dataset_model.get_model_dimension() + + data["irf"] = ( + (model_dimension), + irf.calculate( + index=0, + global_axis=data.coords[global_dimension].values, + model_axis=data.coords[model_dimension].values, + ).data, + ) + center = irf.center if isinstance(irf.center, list) else [irf.center] + width = irf.width if isinstance(irf.width, list) else [irf.width] + data["irf_center"] = ("irf_nr", center) if len(center) > 1 else center[0] + data["irf_width"] = ("irf_nr", width) if len(width) > 1 else width[0] + if isinstance(irf, IrfSpectralMultiGaussian) and irf.dispersion_center: + for i, dispersion in enumerate(irf.calculate_dispersion(data.coords["spectral"].values)): + data[f"center_dispersion_{i+1}"] = ( + global_dimension, + dispersion, + ) diff --git a/glotaran/builtin/megacomplexes/spectral/__init__.py b/glotaran/builtin/megacomplexes/spectral/__init__.py new file mode 100644 index 000000000..13092c602 --- /dev/null +++ b/glotaran/builtin/megacomplexes/spectral/__init__.py @@ -0,0 +1 @@ +from glotaran.builtin.megacomplexes.spectral.spectral_megacomplex import SpectralMegacomplex diff --git a/glotaran/builtin/models/spectral/shape.py b/glotaran/builtin/megacomplexes/spectral/shape.py similarity index 93% rename from glotaran/builtin/models/spectral/shape.py rename to glotaran/builtin/megacomplexes/spectral/shape.py index 5aa80d7b2..7823590b0 100644 --- a/glotaran/builtin/models/spectral/shape.py +++ 
b/glotaran/builtin/megacomplexes/spectral/shape.py @@ -2,12 +2,12 @@ import numpy as np -from glotaran.model import model_attribute -from glotaran.model import model_attribute_typed +from glotaran.model import model_item +from glotaran.model import model_item_typed from glotaran.parameter import Parameter -@model_attribute( +@model_item( properties={ "amplitude": Parameter, "location": Parameter, @@ -92,7 +92,7 @@ def calculate_gaussian(self, axis: np.ndarray) -> np.ndarray: An array representing a Gaussian shape. """ return self.amplitude * np.exp( - -np.log(2) * np.square(2 * ((1e7 / axis) - self.location) / self.width) + -np.log(2) * np.square(2 * (axis - self.location) / self.width) ) def calculate_skewed_gaussian(self, axis: np.ndarray) -> np.ndarray: @@ -147,7 +147,7 @@ def calculate_skewed_gaussian(self, axis: np.ndarray) -> np.ndarray: np.ndarray An array representing a skewed Gaussian shape. """ - log_args = 1 + (2 * self.skewness * ((1e7 / axis) - self.location) / self.width) + log_args = 1 + (2 * self.skewness * (axis - self.location) / self.width) result = np.zeros(log_args.shape) valid_arg_mask = np.where(log_args > 0) result[valid_arg_mask] = self.amplitude * np.exp( @@ -157,7 +157,7 @@ def calculate_skewed_gaussian(self, axis: np.ndarray) -> np.ndarray: return result -@model_attribute(properties={}, has_type=True) +@model_item(properties={}, has_type=True) class SpectralShapeOne: """A constant spectral shape with value 1""" @@ -177,7 +177,7 @@ def calculate(self, axis: np.ndarray) -> np.ndarray: return np.ones(axis.shape[0]) -@model_attribute(properties={}, has_type=True) +@model_item(properties={}, has_type=True) class SpectralShapeZero: """A constant spectral shape with value 0""" @@ -199,7 +199,7 @@ def calculate(self, axis: np.ndarray) -> np.ndarray: return np.zeros(axis.shape[0]) -@model_attribute_typed( +@model_item_typed( types={ "skewed-gaussian": SpectralShapeSkewedGaussian, "one": SpectralShapeOne, diff --git a/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py b/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py new file mode 100644 index 000000000..508e66a49 --- /dev/null +++ b/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from typing import Dict + +import numpy as np +import xarray as xr + +from glotaran.builtin.megacomplexes.spectral.shape import SpectralShape +from glotaran.model import DatasetModel +from glotaran.model import Megacomplex +from glotaran.model import ModelError +from glotaran.model import megacomplex + + +@megacomplex( + dimension="spectral", + properties={"energy_spectrum": {"type": bool, "default": False}}, + model_items={ + "shape": Dict[str, SpectralShape], + }, + register_as="spectral", +) +class SpectralMegacomplex(Megacomplex): + def calculate_matrix( + self, + dataset_model: DatasetModel, + indices: dict[str, int], + **kwargs, + ): + + compartments = [] + for compartment in self.shape: + if compartment in compartments: + raise ModelError(f"More then one shape defined for compartment '{compartment}'") + compartments.append(compartment) + + model_dimension = dataset_model.get_model_dimension() + model_axis = dataset_model.get_coordinates()[model_dimension].data + if self.energy_spectrum: + model_axis = 1e7 / model_axis + + dim1 = model_axis.size + dim2 = len(self.shape) + matrix = np.zeros((dim1, dim2)) + + for i, shape in enumerate(self.shape.values()): + matrix[:, i] += shape.calculate(model_axis) + return xr.DataArray( + matrix, 
coords=((model_dimension, model_axis), ("clp_label", compartments)) + ) + + def index_dependent(self, dataset: DatasetModel) -> bool: + return False + + def finalize_data(self, dataset_model: DatasetModel, data: xr.Dataset): + if "species" in data.coords: + return + + species = [] + for megacomplex in dataset_model.megacomplex: # noqa F402 + if isinstance(megacomplex, SpectralMegacomplex): + species += [ + compartment for compartment in megacomplex.shape if compartment not in species + ] + + data.coords["species"] = species + data["species_spectra"] = ( + ( + dataset_model.get_model_dimension(), + "species", + ), + data.matrix.sel(clp_label=species).values, + ) + data["species_associated_concentrations"] = ( + ( + dataset_model.get_global_dimension(), + "species", + ), + data.clp.sel(clp_label=species).data, + ) diff --git a/glotaran/builtin/models/spectral/test/test_spectral_model.py b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py similarity index 81% rename from glotaran/builtin/models/spectral/test/test_spectral_model.py rename to glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py index be9456c1f..2f2bbd5cc 100644 --- a/glotaran/builtin/models/spectral/test/test_spectral_model.py +++ b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py @@ -5,14 +5,26 @@ from glotaran.analysis.optimize import optimize from glotaran.analysis.simulation import simulate from glotaran.analysis.util import calculate_matrix -from glotaran.builtin.models.kinetic_image import KineticImageModel -from glotaran.builtin.models.spectral.spectral_model import SpectralModel +from glotaran.builtin.megacomplexes.decay.test.test_decay_megacomplex import DecayModel +from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex +from glotaran.model import Model from glotaran.parameter import ParameterGroup from glotaran.project import Scheme +class SpectralModel(Model): + @classmethod + def from_dict(cls, model_dict): + return super().from_dict( + model_dict, + megacomplex_types={ + "spectral": SpectralMegacomplex, + }, + ) + + class OneCompartmentModel: - kinetic_model = KineticImageModel.from_dict( + decay_model = DecayModel.from_dict( { "initial_concentration": { "j1": {"compartments": ["s1"], "parameters": ["2"]}, @@ -36,7 +48,7 @@ class OneCompartmentModel: } ) - kinetic_parameters = ParameterGroup.from_list( + decay_parameters = ParameterGroup.from_list( [101e-4, [1, {"vary": False, "non-negative": False}]] ) @@ -67,17 +79,15 @@ class OneCompartmentModel: spectral = xr.DataArray(np.arange(400, 600, 5)) axis = {"time": time, "spectral": spectral} - kinetic_dataset_model = kinetic_model.dataset["dataset1"].fill( - kinetic_model, kinetic_parameters - ) - kinetic_dataset_model.overwrite_global_dimension("spectral") - kinetic_dataset_model.set_coords(axis) - clp = calculate_matrix(kinetic_dataset_model, {}) - kinetic_compartments = clp.coords["clp_label"].values + decay_dataset_model = decay_model.dataset["dataset1"].fill(decay_model, decay_parameters) + decay_dataset_model.overwrite_global_dimension("spectral") + decay_dataset_model.set_coordinates(axis) + clp = calculate_matrix(decay_dataset_model, {}) + decay_compartments = clp.coords["clp_label"].values class ThreeCompartmentModel: - kinetic_model = KineticImageModel.from_dict( + decay_model = DecayModel.from_dict( { "initial_concentration": { "j1": {"compartments": ["s1", "s2", "s3"], "parameters": ["4", "4", "4"]}, @@ -103,7 +113,7 @@ class ThreeCompartmentModel: } ) - kinetic_parameters = 
ParameterGroup.from_list( + decay_parameters = ParameterGroup.from_list( [101e-4, 101e-5, 101e-6, [1, {"vary": False, "non-negative": False}]] ) @@ -166,13 +176,11 @@ class ThreeCompartmentModel: spectral = xr.DataArray(np.arange(400, 600, 5)) axis = {"time": time, "spectral": spectral} - kinetic_dataset_model = kinetic_model.dataset["dataset1"].fill( - kinetic_model, kinetic_parameters - ) - kinetic_dataset_model.overwrite_global_dimension("spectral") - kinetic_dataset_model.set_coords(axis) - clp = calculate_matrix(kinetic_dataset_model, {}) - kinetic_compartments = clp.coords["clp_label"].values + decay_dataset_model = decay_model.dataset["dataset1"].fill(decay_model, decay_parameters) + decay_dataset_model.overwrite_global_dimension("spectral") + decay_dataset_model.set_coordinates(axis) + clp = calculate_matrix(decay_dataset_model, {}) + decay_compartments = clp.coords["clp_label"].values @pytest.mark.parametrize( @@ -226,10 +234,10 @@ def test_spectral_model(suite): assert "species_associated_concentrations" in resultdata assert resultdata.species_associated_concentrations.shape == ( suite.axis["time"].size, - len(suite.kinetic_compartments), + len(suite.decay_compartments), ) assert "species_spectra" in resultdata assert resultdata.species_spectra.shape == ( suite.axis["spectral"].size, - len(suite.kinetic_compartments), + len(suite.decay_compartments), ) diff --git a/glotaran/builtin/models/__init__.py b/glotaran/builtin/models/__init__.py deleted file mode 100644 index 0d09b4f4c..000000000 --- a/glotaran/builtin/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Glotaran Models Package""" diff --git a/glotaran/builtin/models/kinetic_image/__init__.py b/glotaran/builtin/models/kinetic_image/__init__.py deleted file mode 100644 index 97dc65621..000000000 --- a/glotaran/builtin/models/kinetic_image/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from glotaran.builtin.models.kinetic_image.kinetic_image_model import KineticImageModel diff --git a/glotaran/builtin/models/kinetic_image/initial_concentration.pyi b/glotaran/builtin/models/kinetic_image/initial_concentration.pyi deleted file mode 100644 index c3990b1d3..000000000 --- a/glotaran/builtin/models/kinetic_image/initial_concentration.pyi +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import annotations - -from glotaran.model import DatasetDescriptor -from glotaran.model import model_attribute -from glotaran.parameter import Parameter - -class InitialConcentration: - @property - def compartments(self) -> list[str]: ... - @property - def parameters(self) -> list[Parameter]: ... - @property - def exclude_from_normalize(self) -> list[Parameter]: ... - def normalized(self, dataset: DatasetDescriptor) -> InitialConcentration: ... diff --git a/glotaran/builtin/models/kinetic_image/irf.pyi b/glotaran/builtin/models/kinetic_image/irf.pyi deleted file mode 100644 index 40fbc511a..000000000 --- a/glotaran/builtin/models/kinetic_image/irf.pyi +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -from typing import Any - -from glotaran.model import model_attribute -from glotaran.model import model_attribute_typed -from glotaran.parameter import Parameter - -class IrfMeasured: ... # noqa: E701 - -class IrfMultiGaussian: - @property - def center(self) -> list[Parameter]: ... - @property - def width(self) -> list[Parameter]: ... - @property - def scale(self) -> list[Parameter]: ... - @property - def backsweep_period(self) -> Parameter: ... - def parameter(self, index: Any): ... - def calculate(self, index: Any, axis: Any): ... 
- -class IrfGaussian(IrfMultiGaussian): - @property - def center(self) -> Parameter: ... - @property - def width(self) -> Parameter: ... - -class Irf: - @classmethod - def add_type(cls, type_name: str, attribute_type: type) -> None: ... diff --git a/glotaran/builtin/models/kinetic_image/k_matrix.pyi b/glotaran/builtin/models/kinetic_image/k_matrix.pyi deleted file mode 100644 index adc169cc3..000000000 --- a/glotaran/builtin/models/kinetic_image/k_matrix.pyi +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import numpy as np - -from glotaran.builtin.models.kinetic_image.initial_concentration import InitialConcentration -from glotaran.model import model_attribute -from glotaran.parameter import Parameter - -class KMatrix: - @classmethod - def empty(cls: Any, label: str, compartments: list[str]) -> KMatrix: ... - def involved_compartments(self) -> list[str]: ... - def combine(self, k_matrix: KMatrix) -> KMatrix: ... - def matrix_as_markdown( - self, compartments: list[str] = ..., fill_parameters: bool = ... - ) -> str: ... - def a_matrix_as_markdown(self, initial_concentration: InitialConcentration) -> str: ... - def reduced(self, compartments: list[str]) -> np.ndarray: ... - def full(self, compartments: list[str]) -> np.ndarray: ... - def eigen(self, compartments: list[str]) -> tuple[np.ndarray, np.ndarray]: ... - def rates(self, initial_concentration: InitialConcentration) -> np.ndarray: ... - def a_matrix(self, initial_concentration: InitialConcentration) -> np.ndarray: ... - def a_matrix_non_unibranch( - self, initial_concentration: InitialConcentration - ) -> np.ndarray: ... - def a_matrix_unibranch(self, initial_concentration: InitialConcentration) -> np.array: ... - def is_unibranched(self, initial_concentration: InitialConcentration) -> bool: ... - @property - def matrix(self) -> dict[tuple[str, str], Parameter]: ... 
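The deletions above and below remove the old glotaran.builtin.models.kinetic_image package; its role is taken over by the glotaran.builtin.megacomplexes.decay package introduced earlier in this patch. For orientation, a minimal sketch of the new construction pattern, following the usage shown in the updated tests above (test_baseline_megacomplex.py and test_coherent_artifact.py); model_dict here stands in for any spec dict like the ones in those tests, and the sketch is illustrative rather than canonical:

    from glotaran.builtin.megacomplexes.baseline import BaselineMegacomplex
    from glotaran.builtin.megacomplexes.decay import DecayMegacomplex
    from glotaran.model import Model

    # Megacomplex types are now passed explicitly instead of selecting a
    # dedicated model class such as the removed KineticImageModel.
    model = Model.from_dict(
        model_dict,
        megacomplex_types={"decay": DecayMegacomplex, "baseline": BaselineMegacomplex},
    )
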
diff --git a/glotaran/builtin/models/kinetic_image/kinetic_baseline_megacomplex.py b/glotaran/builtin/models/kinetic_image/kinetic_baseline_megacomplex.py deleted file mode 100644 index 4021cb0b5..000000000 --- a/glotaran/builtin/models/kinetic_image/kinetic_baseline_megacomplex.py +++ /dev/null @@ -1,29 +0,0 @@ -"""This package contains the kinetic megacomplex item.""" -from __future__ import annotations - -import numpy as np -import xarray as xr - -from glotaran.model import DatasetDescriptor -from glotaran.model import Megacomplex -from glotaran.model import megacomplex - - -@megacomplex("time") -class KineticBaselineMegacomplex(Megacomplex): - def calculate_matrix( - self, - dataset_model: DatasetDescriptor, - indices: dict[str, int], - **kwargs, - ): - model_dimension = dataset_model.get_model_dimension() - model_axis = dataset_model.get_coords()[model_dimension] - compartments = [f"{dataset_model.label}_baseline"] - matrix = np.ones((model_axis.size, 1), dtype=np.float64) - return xr.DataArray( - matrix, coords=((model_dimension, model_axis.data), ("clp_label", compartments)) - ) - - def index_dependent(self, dataset: DatasetDescriptor) -> bool: - return False diff --git a/glotaran/builtin/models/kinetic_image/kinetic_decay_megacomplex.py b/glotaran/builtin/models/kinetic_image/kinetic_decay_megacomplex.py deleted file mode 100644 index e9995df65..000000000 --- a/glotaran/builtin/models/kinetic_image/kinetic_decay_megacomplex.py +++ /dev/null @@ -1,189 +0,0 @@ -"""This package contains the kinetic megacomplex item.""" -from __future__ import annotations - -from typing import List - -import numba as nb -import numpy as np -import xarray as xr - -from glotaran.builtin.models.kinetic_image.irf import IrfMultiGaussian -from glotaran.model import DatasetDescriptor -from glotaran.model import Megacomplex -from glotaran.model import ModelError -from glotaran.model import megacomplex - - -@megacomplex( - "time", - properties={ - "k_matrix": List[str], - }, -) -class KineticDecayMegacomplex(Megacomplex): - """A Megacomplex with one or more K-Matrices.""" - - def has_k_matrix(self) -> bool: - return len(self.k_matrix) != 0 - - def full_k_matrix(self, model=None): - full_k_matrix = None - for k_matrix in self.k_matrix: - if model: - k_matrix = model.k_matrix[k_matrix] - if full_k_matrix is None: - full_k_matrix = k_matrix - # If multiple k matrices are present, we combine them - else: - full_k_matrix = full_k_matrix.combine(k_matrix) - return full_k_matrix - - @property - def involved_compartments(self): - return self.full_k_matrix().involved_compartments() if self.full_k_matrix() else [] - - def index_dependent(self, dataset: DatasetDescriptor) -> bool: - return False - - def calculate_matrix( - self, - dataset_model: DatasetDescriptor, - indices: dict[str, int], - **kwargs, - ): - if dataset_model.initial_concentration is None: - raise ModelError( - f'No initial concentration specified in dataset "{dataset_model.label}"' - ) - initial_concentration = dataset_model.initial_concentration.normalized() - - k_matrix = self.full_k_matrix() - - # we might have more compartments in the model then in the k matrix - compartments = [ - comp - for comp in initial_concentration.compartments - if comp in k_matrix.involved_compartments() - ] - - # the rates are the eigenvalues of the k matrix - rates = k_matrix.rates(initial_concentration) - - global_dimension = dataset_model.get_global_dimension() - global_index = indices.get(global_dimension) - global_axis = 
dataset_model.get_coords().get(global_dimension).values - model_dimension = dataset_model.get_model_dimension() - model_axis = dataset_model.get_coords()[model_dimension].values - - # init the matrix - size = (model_axis.size, rates.size) - matrix = np.zeros(size, dtype=np.float64) - - kinetic_image_matrix_implementation( - matrix, rates, global_index, global_axis, model_axis, dataset_model - ) - - if not np.all(np.isfinite(matrix)): - raise ValueError( - f"Non-finite concentrations for K-Matrix '{k_matrix.label}':\n" - f"{k_matrix.matrix_as_markdown(fill_parameters=True)}" - ) - - # apply A matrix - matrix = matrix @ k_matrix.a_matrix(initial_concentration) - - # done - return xr.DataArray( - matrix, coords=((model_dimension, model_axis), ("clp_label", compartments)) - ) - - -def kinetic_image_matrix_implementation( - matrix: np.ndarray, - rates: np.ndarray, - global_index: int, - global_axis: np.ndarray, - model_axis: np.ndarray, - dataset_model: DatasetDescriptor, -): - if isinstance(dataset_model.irf, IrfMultiGaussian): - - ( - centers, - widths, - irf_scales, - shift, - backsweep, - backsweep_period, - ) = dataset_model.irf.parameter(global_index, global_axis) - - for center, width, irf_scale in zip(centers, widths, irf_scales): - calculate_kinetic_matrix_gaussian_irf( - matrix, - rates, - model_axis, - center - shift, - width, - irf_scale, - backsweep, - backsweep_period, - ) - if dataset_model.irf.normalize: - matrix /= np.sum(irf_scale) - - else: - calculate_kinetic_matrix_no_irf(matrix, rates, model_axis) - - -@nb.jit(nopython=True, parallel=True) -def calculate_kinetic_matrix_no_irf(matrix, rates, times): - for n_r in nb.prange(rates.size): - r_n = rates[n_r] - for n_t in range(times.size): - t_n = times[n_t] - matrix[n_t, n_r] += np.exp(r_n * t_n) - - -sqrt2 = np.sqrt(2) - - -@nb.jit(nopython=True, parallel=True) -def calculate_kinetic_matrix_gaussian_irf( - matrix, rates, times, center, width, scale, backsweep, backsweep_period -): - """Calculates a kinetic matrix with a gaussian irf.""" - for n_r in nb.prange(rates.size): - r_n = -rates[n_r] - backsweep_valid = abs(r_n) * backsweep_period > 0.001 - alpha = (r_n * width) / sqrt2 - for n_t in nb.prange(times.size): - t_n = times[n_t] - beta = (t_n - center) / (width * sqrt2) - thresh = beta - alpha - if thresh < -1: - matrix[n_t, n_r] += scale * 0.5 * erfcx(-thresh) * np.exp(-beta * beta) - else: - matrix[n_t, n_r] += ( - scale * 0.5 * (1 + erf(thresh)) * np.exp(alpha * (alpha - 2 * beta)) - ) - if backsweep and backsweep_valid: - x1 = np.exp(-r_n * (t_n - center + backsweep_period)) - x2 = np.exp(-r_n * ((backsweep_period / 2) - (t_n - center))) - x3 = np.exp(-r_n * backsweep_period) - matrix[n_t, n_r] += scale * (x1 + x2) / (1 - x3) - - -import ctypes # noqa: E402 - -# This is a work around to use scipy.special function with numba -from numba.extending import get_cython_function_address # noqa: E402 - -_dble = ctypes.c_double - -functype = ctypes.CFUNCTYPE(_dble, _dble) - -erf_addr = get_cython_function_address("scipy.special.cython_special", "__pyx_fuse_1erf") -erfcx_addr = get_cython_function_address("scipy.special.cython_special", "__pyx_fuse_1erfcx") - -erf = functype(erf_addr) -erfcx = functype(erfcx_addr) diff --git a/glotaran/builtin/models/kinetic_image/kinetic_image_dataset_descriptor.py b/glotaran/builtin/models/kinetic_image/kinetic_image_dataset_descriptor.py deleted file mode 100644 index 30aac6d90..000000000 --- a/glotaran/builtin/models/kinetic_image/kinetic_image_dataset_descriptor.py +++ /dev/null @@ -1,14 
+0,0 @@ -""" Kinetic Image Dataset Descriptor""" - -from glotaran.model import DatasetDescriptor -from glotaran.model import model_attribute - - -@model_attribute( - properties={ - "initial_concentration": {"type": str, "allow_none": True}, - "irf": {"type": str, "allow_none": True}, - } -) -class KineticImageDatasetDescriptor(DatasetDescriptor): - pass diff --git a/glotaran/builtin/models/kinetic_image/kinetic_image_dataset_descriptor.pyi b/glotaran/builtin/models/kinetic_image/kinetic_image_dataset_descriptor.pyi deleted file mode 100644 index 55f3ee5a0..000000000 --- a/glotaran/builtin/models/kinetic_image/kinetic_image_dataset_descriptor.pyi +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import annotations - -from glotaran.model import DatasetDescriptor -from glotaran.model import model_attribute - -class KineticImageDatasetDescriptor(DatasetDescriptor): - @property - def initial_concentration(self) -> str: ... - @property - def irf(self) -> str: ... - @property - def baseline(self) -> bool: ... - def get_k_matrices(self): ... - def compartments(self): ... diff --git a/glotaran/builtin/models/kinetic_image/kinetic_image_model.py b/glotaran/builtin/models/kinetic_image/kinetic_image_model.py deleted file mode 100644 index b013efdd9..000000000 --- a/glotaran/builtin/models/kinetic_image/kinetic_image_model.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -from glotaran.builtin.models.kinetic_image.initial_concentration import InitialConcentration -from glotaran.builtin.models.kinetic_image.irf import Irf -from glotaran.builtin.models.kinetic_image.irf import IrfMultiGaussian -from glotaran.builtin.models.kinetic_image.k_matrix import KMatrix -from glotaran.builtin.models.kinetic_image.kinetic_baseline_megacomplex import ( - KineticBaselineMegacomplex, -) -from glotaran.builtin.models.kinetic_image.kinetic_decay_megacomplex import KineticDecayMegacomplex -from glotaran.builtin.models.kinetic_image.kinetic_image_dataset_descriptor import ( - KineticImageDatasetDescriptor, -) -from glotaran.builtin.models.kinetic_image.kinetic_image_result import ( - finalize_kinetic_image_result, -) -from glotaran.model import Model -from glotaran.model import model - - -def index_dependent(model: KineticImageModel) -> bool: - return any( - isinstance(irf, IrfMultiGaussian) and irf.shift is not None for irf in model.irf.values() - ) - - -@model( - "kinetic-image", - attributes={ - "initial_concentration": InitialConcentration, - "k_matrix": KMatrix, - "irf": Irf, - }, - dataset_type=KineticImageDatasetDescriptor, - default_megacomplex_type="kinetic-decay", - megacomplex_types={ - "kinetic-decay": KineticDecayMegacomplex, - "kinetic-baseline": KineticBaselineMegacomplex, - }, - model_dimension="time", - global_dimension="pixel", - grouped=False, - index_dependent=index_dependent, - finalize_data_function=finalize_kinetic_image_result, -) -class KineticImageModel(Model): - pass diff --git a/glotaran/builtin/models/kinetic_image/kinetic_image_model.pyi b/glotaran/builtin/models/kinetic_image/kinetic_image_model.pyi deleted file mode 100644 index 375e3730e..000000000 --- a/glotaran/builtin/models/kinetic_image/kinetic_image_model.pyi +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -from typing import Any -from typing import Mapping - -import numpy as np - -from glotaran.builtin.models.kinetic_image.initial_concentration import InitialConcentration -from glotaran.builtin.models.kinetic_image.irf import Irf -from glotaran.builtin.models.kinetic_image.k_matrix import 
KMatrix -from glotaran.builtin.models.kinetic_image.kinetic_image_dataset_descriptor import ( - KineticImageDatasetDescriptor, -) -from glotaran.builtin.models.kinetic_image.kinetic_image_matrix import kinetic_image_matrix -from glotaran.builtin.models.kinetic_image.kinetic_image_megacomplex import KineticImageMegacomplex -from glotaran.builtin.models.kinetic_image.kinetic_image_result import ( - finalize_kinetic_image_result, -) -from glotaran.model import Model -from glotaran.model import model - -class KineticImageModel(Model): - dataset: Mapping[str, KineticImageDatasetDescriptor] - megacomplex: Mapping[str, KineticImageMegacomplex] - @staticmethod - def matrix( # type: ignore[override] - dataset_descriptor: KineticImageDatasetDescriptor = ..., axis=..., index=..., irf=... - ) -> tuple[None, None] | tuple[list[Any], np.ndarray]: ... - @property - def initial_concentration(self) -> Mapping[str, InitialConcentration]: ... - @property - def k_matrix(self) -> Mapping[str, KMatrix]: ... - @property - def irf(self) -> Mapping[str, Irf]: ... diff --git a/glotaran/builtin/models/kinetic_image/kinetic_image_result.py b/glotaran/builtin/models/kinetic_image/kinetic_image_result.py deleted file mode 100644 index 710c8093f..000000000 --- a/glotaran/builtin/models/kinetic_image/kinetic_image_result.py +++ /dev/null @@ -1,153 +0,0 @@ -from __future__ import annotations - -import xarray as xr - -from glotaran.analysis.problem import Problem -from glotaran.builtin.models.kinetic_image.irf import IrfMultiGaussian -from glotaran.builtin.models.kinetic_image.kinetic_baseline_megacomplex import ( - KineticBaselineMegacomplex, -) -from glotaran.builtin.models.kinetic_image.kinetic_decay_megacomplex import KineticDecayMegacomplex - - -def finalize_kinetic_image_result(model, problem: Problem, data: dict[str, xr.Dataset]): - - for label, dataset in data.items(): - - dataset_model = problem.filled_dataset_descriptors[label] - - retrieve_species_associated_data(problem.model, dataset, dataset_model, "images") - retrieve_decay_associated_data(problem.model, dataset, dataset_model, "images") - - if any( - isinstance(megacomplex, KineticBaselineMegacomplex) - for megacomplex in dataset_model.megacomplex - ): - dataset["baseline"] = dataset.clp.sel(clp_label=f"{dataset_model.label}_baseline") - - retrieve_irf(problem.model, dataset, dataset_model, "images") - - -def retrieve_species_associated_data(model, dataset, dataset_model, name): - compartments = dataset_model.initial_concentration.compartments - global_dimension = dataset_model.get_global_dimension() - model_dimension = dataset_model.get_model_dimension() - - dataset.coords["species"] = compartments - dataset[f"species_associated_{name}"] = ( - ( - global_dimension, - "species", - ), - dataset.clp.sel(clp_label=compartments).data, - ) - - if len(dataset.matrix.shape) == 3: - # index dependent - dataset["species_concentration"] = ( - ( - global_dimension, - model_dimension, - "species", - ), - dataset.matrix.sel(clp_label=compartments).values, - ) - else: - # index independent - dataset["species_concentration"] = ( - ( - model_dimension, - "species", - ), - dataset.matrix.sel(clp_label=compartments).values, - ) - - -def retrieve_decay_associated_data(model, dataset, dataset_model, name): - # get_das - all_das = [] - all_a_matrix = [] - all_k_matrix = [] - all_k_matrix_reduced = [] - all_das_labels = [] - - global_dimension = dataset_model.get_global_dimension() - - for megacomplex in dataset_model.megacomplex: - - if isinstance(megacomplex, 
KineticDecayMegacomplex): - k_matrix = megacomplex.full_k_matrix() - - compartments = dataset_model.initial_concentration.compartments - compartments = [c for c in compartments if c in k_matrix.involved_compartments()] - - matrix = k_matrix.full(compartments) - matrix_reduced = k_matrix.reduced(compartments) - a_matrix = k_matrix.a_matrix(dataset_model.initial_concentration) - rates = k_matrix.rates(dataset_model.initial_concentration) - lifetimes = 1 / rates - - das = ( - dataset[f"species_associated_{name}"].sel(species=compartments).values @ a_matrix.T - ) - - component_coords = {"rate": ("component", rates), "lifetime": ("component", lifetimes)} - - das_coords = component_coords.copy() - das_coords[global_dimension] = dataset.coords[global_dimension] - all_das_labels.append(megacomplex.label) - all_das.append( - xr.DataArray(das, dims=(global_dimension, "component"), coords=das_coords) - ) - a_matrix_coords = component_coords.copy() - a_matrix_coords["species"] = compartments - all_a_matrix.append( - xr.DataArray(a_matrix, coords=a_matrix_coords, dims=("component", "species")) - ) - all_k_matrix.append( - xr.DataArray( - matrix, coords=[("to_species", compartments), ("from_species", compartments)] - ) - ) - - all_k_matrix_reduced.append( - xr.DataArray( - matrix_reduced, - coords=[("to_species", compartments), ("from_species", compartments)], - ) - ) - - if all_das: - if len(all_das) == 1: - dataset[f"decay_associated_{name}"] = all_das[0] - dataset["a_matrix"] = all_a_matrix[0] - dataset["k_matrix"] = all_k_matrix[0] - dataset["k_matrix_reduced"] = all_k_matrix_reduced[0] - - else: - for i, das_label in enumerate(all_das_labels): - dataset[f"decay_associated_{name}_{das_label}"] = all_das[i].rename( - component=f"component_{das_label}" - ) - dataset[f"a_matrix_{das_label}"] = all_a_matrix[i].rename( - component=f"component_{das_label}" - ) - dataset[f"k_matrix_{das_label}"] = all_k_matrix[i] - dataset[f"k_matrix_reduced_{das_label}"] = all_k_matrix_reduced[i] - - -def retrieve_irf(model, dataset, dataset_model, name): - - irf = dataset_model.irf - global_dimension = dataset_model.get_global_dimension() - model_dimension = dataset_model.get_model_dimension() - - if isinstance(irf, IrfMultiGaussian): - dataset["irf"] = ( - (model_dimension), - irf.calculate( - index=0, - global_axis=dataset.coords[global_dimension], - model_axis=dataset.coords[model_dimension], - ).data, - ) diff --git a/glotaran/builtin/models/kinetic_spectrum/__init__.py b/glotaran/builtin/models/kinetic_spectrum/__init__.py deleted file mode 100644 index 99ce8db34..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_model import KineticSpectrumModel diff --git a/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_dataset_descriptor.py b/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_dataset_descriptor.py deleted file mode 100644 index 8b0076231..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_dataset_descriptor.py +++ /dev/null @@ -1,15 +0,0 @@ -import typing - -from glotaran.builtin.models.kinetic_image.kinetic_image_dataset_descriptor import ( - KineticImageDatasetDescriptor, -) -from glotaran.model import model_attribute - - -@model_attribute( - properties={ - "shape": {"type": typing.Dict[str, str], "allow_none": True}, - } -) -class KineticSpectrumDatasetDescriptor(KineticImageDatasetDescriptor): - pass diff --git 
a/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_dataset_descriptor.pyi b/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_dataset_descriptor.pyi deleted file mode 100644 index 2586cf813..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_dataset_descriptor.pyi +++ /dev/null @@ -1,10 +0,0 @@ -from __future__ import annotations - -from glotaran.builtin.models.kinetic_image.kinetic_image_dataset_descriptor import ( - KineticImageDatasetDescriptor, -) -from glotaran.model import model_attribute - -class KineticSpectrumDatasetDescriptor(KineticImageDatasetDescriptor): - @property - def shape(self) -> dict[str, str]: ... diff --git a/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_model.py b/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_model.py deleted file mode 100644 index 9ae414746..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_model.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import numpy as np -import xarray as xr - -from glotaran.builtin.models.kinetic_image.kinetic_baseline_megacomplex import ( - KineticBaselineMegacomplex, -) -from glotaran.builtin.models.kinetic_image.kinetic_decay_megacomplex import KineticDecayMegacomplex -from glotaran.builtin.models.kinetic_image.kinetic_image_model import KineticImageModel -from glotaran.builtin.models.kinetic_spectrum.coherent_artifact_megacomplex import ( - CoherentArtifactMegacomplex, -) -from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_dataset_descriptor import ( - KineticSpectrumDatasetDescriptor, -) -from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_result import ( - finalize_kinetic_spectrum_result, -) -from glotaran.builtin.models.kinetic_spectrum.spectral_constraints import SpectralConstraint -from glotaran.builtin.models.kinetic_spectrum.spectral_constraints import ( - apply_spectral_constraints, -) -from glotaran.builtin.models.kinetic_spectrum.spectral_irf import IrfSpectralMultiGaussian -from glotaran.builtin.models.kinetic_spectrum.spectral_matrix import spectral_matrix -from glotaran.builtin.models.kinetic_spectrum.spectral_penalties import EqualAreaPenalty -from glotaran.builtin.models.kinetic_spectrum.spectral_penalties import apply_spectral_penalties -from glotaran.builtin.models.kinetic_spectrum.spectral_penalties import has_spectral_penalties -from glotaran.builtin.models.kinetic_spectrum.spectral_relations import SpectralRelation -from glotaran.builtin.models.kinetic_spectrum.spectral_relations import apply_spectral_relations -from glotaran.builtin.models.kinetic_spectrum.spectral_relations import retrieve_related_clps -from glotaran.builtin.models.kinetic_spectrum.spectral_shape import SpectralShape -from glotaran.model import model - -if TYPE_CHECKING: - from glotaran.parameter import ParameterGroup - - -def has_kinetic_model_constraints(model: KineticSpectrumModel) -> bool: - return len(model.spectral_relations) + len(model.spectral_constraints) != 0 - - -def apply_kinetic_model_constraints( - model: KineticSpectrumModel, - dataset: str, - parameters: ParameterGroup, - clp_labels: list[str], - matrix: np.ndarray, - index: float, -) -> tuple[list[str], np.ndarray]: - clp_labels, matrix = apply_spectral_relations( - model, dataset, parameters, clp_labels, matrix, index - ) - clp_labels, matrix = apply_spectral_constraints(model, clp_labels, matrix, index) - return clp_labels, matrix - - -def retrieve_spectral_clps( - model: 
KineticSpectrumModel, - parameters: ParameterGroup, - clp_labels: dict[str, list[str] | list[list[str]]], - reduced_clp_labels: dict[str, list[str] | list[list[str]]], - reduced_clps: dict[str, list[np.ndarray]], - data: dict[str, xr.Dataset], -) -> dict[str, list[np.ndarray]]: - if not has_kinetic_model_constraints(model): - return reduced_clps - - # Note: we are always in index_dependent case when we have constraints - clps = {} - for label in clp_labels: - clps[label] = [] - for i, index_reduced_clp_labels in enumerate(reduced_clp_labels[label]): - index_clp_labels = clp_labels[label][i] - index_reduced_clps = reduced_clps[label][i] - index_clps = np.zeros((len(index_clp_labels)), dtype=np.float64) - for j, clp_label in enumerate(index_reduced_clp_labels): - index_clps[index_clp_labels.index(clp_label)] = index_reduced_clps[j] - clps[label].append(index_clps) - clps = retrieve_related_clps(model, parameters, clp_labels, clps, data) - return clps - - -def index_dependent(model: KineticSpectrumModel) -> bool: - return ( - any( - isinstance(irf, IrfSpectralMultiGaussian) and irf.dispersion_center is not None - for irf in model.irf.values() - ) - or len(model.spectral_relations) != 0 - or len(model.spectral_constraints) != 0 - or len(model.weights) != 0 - ) - - -def grouped(model: KineticSpectrumModel): - return len(model.dataset) != 1 - - -@model( - "kinetic-spectrum", - attributes={ - "equal_area_penalties": EqualAreaPenalty, - "shape": SpectralShape, - "spectral_constraints": SpectralConstraint, - "spectral_relations": SpectralRelation, - }, - dataset_type=KineticSpectrumDatasetDescriptor, - default_megacomplex_type="kinetic-decay", - megacomplex_types={ - "coherent-artifact": CoherentArtifactMegacomplex, - "kinetic-decay": KineticDecayMegacomplex, - "kinetic-baseline": KineticBaselineMegacomplex, - }, - model_dimension="time", - global_matrix=spectral_matrix, - global_dimension="spectral", - has_matrix_constraints_function=has_kinetic_model_constraints, - constrain_matrix_function=apply_kinetic_model_constraints, - retrieve_clp_function=retrieve_spectral_clps, - has_additional_penalty_function=has_spectral_penalties, - additional_penalty_function=apply_spectral_penalties, - grouped=grouped, - index_dependent=index_dependent, - finalize_data_function=finalize_kinetic_spectrum_result, -) -class KineticSpectrumModel(KineticImageModel): - pass diff --git a/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_model.pyi b/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_model.pyi deleted file mode 100644 index 9fa9da1a6..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_model.pyi +++ /dev/null @@ -1,93 +0,0 @@ -from __future__ import annotations - -from typing import Any -from typing import Mapping - -import numpy as np - -from glotaran.builtin.models.kinetic_image.kinetic_image_megacomplex import KineticImageMegacomplex -from glotaran.builtin.models.kinetic_image.kinetic_image_model import KineticImageModel -from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_dataset_descriptor import ( - KineticSpectrumDatasetDescriptor, -) -from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_matrix import ( - kinetic_spectrum_matrix, -) -from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_result import ( - finalize_kinetic_spectrum_result, -) -from glotaran.builtin.models.kinetic_spectrum.spectral_constraints import SpectralConstraint -from glotaran.builtin.models.kinetic_spectrum.spectral_constraints import ( - 
apply_spectral_constraints, -) -from glotaran.builtin.models.kinetic_spectrum.spectral_irf import IrfSpectralMultiGaussian -from glotaran.builtin.models.kinetic_spectrum.spectral_matrix import spectral_matrix -from glotaran.builtin.models.kinetic_spectrum.spectral_penalties import EqualAreaPenalty -from glotaran.builtin.models.kinetic_spectrum.spectral_penalties import apply_spectral_penalties -from glotaran.builtin.models.kinetic_spectrum.spectral_penalties import has_spectral_penalties -from glotaran.builtin.models.kinetic_spectrum.spectral_relations import SpectralRelation -from glotaran.builtin.models.kinetic_spectrum.spectral_relations import apply_spectral_relations -from glotaran.builtin.models.kinetic_spectrum.spectral_relations import retrieve_related_clps -from glotaran.builtin.models.kinetic_spectrum.spectral_shape import SpectralShape -from glotaran.model import model -from glotaran.parameter import ParameterGroup - -def has_kinetic_model_constraints(model: KineticSpectrumModel) -> bool: ... # noqa: F811 -def apply_kinetic_model_constraints( - model: KineticSpectrumModel, # noqa: F811 - parameters: ParameterGroup, - clp_labels: list[str], - matrix: np.ndarray, - index: float, -) -> Any: ... -def retrieve_spectral_clps( - model: KineticSpectrumModel, # noqa: F811 - parameters: ParameterGroup, - clp_labels: list[str], - reduced_clp_labels: list[str], - reduced_clps: np.ndarray | list[np.ndarray], - global_axis: np.ndarray, -) -> Any: ... -def index_dependent(model: KineticSpectrumModel) -> Any: ... # noqa: F811 -def grouped(model: KineticSpectrumModel) -> bool: ... # noqa: F811 - -class KineticSpectrumModel(KineticImageModel): - dataset: Mapping[str, KineticSpectrumDatasetDescriptor] - megacomplex: Mapping[str, KineticImageMegacomplex] - @property - def equal_area_penalties(self) -> list[EqualAreaPenalty]: ... - @property - def shape(self) -> Mapping[str, SpectralShape]: ... - @property - def spectral_constraints(self) -> list[SpectralConstraint]: ... - @property - def spectral_relations(self) -> list[SpectralRelation]: ... - def has_matrix_constraints_function(self) -> bool: ... - def constrain_matrix_function( - self, parameters: ParameterGroup, clp_labels: list[str], matrix: np.ndarray, index: float - ) -> tuple[list[str], np.ndarray]: ... - def retrieve_clp_function( - self, - parameters: ParameterGroup, - clp_labels: list[str], - reduced_clp_labels: list[str], - reduced_clps: np.ndarray | list[np.ndarray], - global_axis: np.ndarray, - ) -> np.ndarray | list[np.ndarray]: ... - def has_additional_penalty_function(self) -> bool: ... - def additional_penalty_function( - self, - parameters: ParameterGroup, - clp_labels: list[str] | list[list[str]], - clps: np.ndarray, - global_axis: np.ndarray, - ) -> np.ndarray: ... - @staticmethod - def global_matrix(dataset, axis) -> tuple[None, None] | tuple[list[str], np.ndarray]: ... - @staticmethod - def matrix( # type: ignore[override] - dataset_descriptor: KineticSpectrumDatasetDescriptor = ..., - axis=..., - index=..., - irf=..., - ) -> tuple[None, None] | tuple[list[Any], np.ndarray]: ... 
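The retrieve_spectral_clps hook deleted above scatters each reduced clp vector back into the full clp label order (zeros where a constraint or relation removed a label) before related clps are restored. A standalone sketch of that scatter step, with illustrative labels and values, could look like:

import numpy as np

def expand_reduced_clps(clp_labels, reduced_clp_labels, reduced_clps):
    """Scatter reduced clps back into a full-length clp vector; removed labels stay zero."""
    full_clps = np.zeros(len(clp_labels), dtype=np.float64)
    for label, value in zip(reduced_clp_labels, reduced_clps):
        full_clps[clp_labels.index(label)] = value
    return full_clps

# 's2' was dropped by a constraint/relation, so it remains zero here.
print(expand_reduced_clps(["s1", "s2", "s3"], ["s1", "s3"], np.array([1.5, 0.7])))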
diff --git a/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_result.py b/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_result.py deleted file mode 100644 index fa647d316..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/kinetic_spectrum_result.py +++ /dev/null @@ -1,82 +0,0 @@ -from __future__ import annotations - -import xarray as xr - -from glotaran.analysis.problem import Problem -from glotaran.builtin.models.kinetic_image.irf import IrfMultiGaussian -from glotaran.builtin.models.kinetic_image.kinetic_baseline_megacomplex import ( - KineticBaselineMegacomplex, -) -from glotaran.builtin.models.kinetic_image.kinetic_image_result import ( - retrieve_decay_associated_data, -) -from glotaran.builtin.models.kinetic_image.kinetic_image_result import retrieve_irf -from glotaran.builtin.models.kinetic_image.kinetic_image_result import ( - retrieve_species_associated_data, -) -from glotaran.builtin.models.kinetic_spectrum.coherent_artifact_megacomplex import ( - CoherentArtifactMegacomplex, -) -from glotaran.builtin.models.kinetic_spectrum.spectral_irf import IrfSpectralMultiGaussian - - -def finalize_kinetic_spectrum_result(model, problem: Problem, data: dict[str, xr.Dataset]): - - for label, dataset in data.items(): - - dataset_model = problem.filled_dataset_descriptors[label] - global_dimension = dataset_model.get_global_dimension() - model_dimension = dataset_model.get_model_dimension() - - if any( - isinstance(megacomplex, KineticBaselineMegacomplex) - for megacomplex in dataset_model.megacomplex - ): - dataset["baseline"] = dataset.clp.sel(clp_label=f"{dataset_model.label}_baseline") - - retrieve_species_associated_data(problem.model, dataset, dataset_model, "spectra") - - retrieve_decay_associated_data(problem.model, dataset, dataset_model, "spectra") - - irf = dataset_model.irf - if isinstance(irf, IrfMultiGaussian): - if isinstance(irf.center, list): - dataset["irf_center"] = irf.center[0].value - dataset["irf_width"] = irf.width[0].value - else: - dataset["irf_center"] = irf.center.value - dataset["irf_width"] = irf.width.value - elif isinstance(irf, IrfSpectralMultiGaussian): - - dataset["irf"] = ( - ("time"), - irf.calculate(0, dataset.coords["spectral"], dataset.coords["time"]), - ) - - if irf.dispersion_center: - for i, dispersion in enumerate( - irf.calculate_dispersion(dataset.coords["spectral"].values) - ): - dataset[f"center_dispersion_{i+1}"] = ( - global_dimension, - dispersion, - ) - else: - retrieve_irf(problem.model, dataset, dataset_model, "images") - - if any( - isinstance(megacomplex, CoherentArtifactMegacomplex) - for megacomplex in dataset_model.megacomplex - ): - coherent_artifact = [ - c for c in dataset_model.megacomplex if isinstance(c, CoherentArtifactMegacomplex) - ][0] - dataset.coords["coherent_artifact_order"] = list(range(1, coherent_artifact.order + 1)) - dataset["coherent_artifact_concentration"] = ( - (model_dimension, "coherent_artifact_order"), - dataset.matrix.sel(clp_label=coherent_artifact.compartments()).values, - ) - dataset["coherent_artifact_associated_spectra"] = ( - (global_dimension, "coherent_artifact_order"), - dataset.clp.sel(clp_label=coherent_artifact.compartments()).values, - ) diff --git a/glotaran/builtin/models/kinetic_spectrum/spectral_constraints.py b/glotaran/builtin/models/kinetic_spectrum/spectral_constraints.py deleted file mode 100644 index 6b3e411b5..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/spectral_constraints.py +++ /dev/null @@ -1,115 +0,0 @@ -"""This package contains compartment 
constraint items.""" -from __future__ import annotations - -from typing import TYPE_CHECKING -from typing import List -from typing import Tuple - -from glotaran.model import model_attribute -from glotaran.model import model_attribute_typed - -if TYPE_CHECKING: - from typing import Any - - import numpy as np - - from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_model import ( - KineticSpectrumModel, - ) - - -@model_attribute( - properties={ - "compartment": str, - "interval": List[Tuple[float, float]], - }, - has_type=True, - no_label=True, -) -class OnlyConstraint: - """A only constraint sets the calculated matrix row of a compartment to 0 - outside the given intervals.""" - - def applies(self, index: Any) -> bool: - """ - Returns true if the index is in one of the intervals. - - Parameters - ---------- - index : - - Returns - ------- - applies : bool - - """ - - def applies(interval): - return interval[0] <= index <= interval[1] - - if isinstance(self.interval, tuple): - return applies(self.interval) - return not any([applies(i) for i in self.interval]) - - -@model_attribute( - properties={ - "compartment": str, - "interval": List[Tuple[float, float]], - }, - has_type=True, - no_label=True, -) -class ZeroConstraint: - """A zero constraint sets the calculated matrix row of a compartment to 0 - in the given intervals.""" - - def applies(self, index: Any) -> bool: - """ - Returns true if the indexx is in one of the intervals. - - Parameters - ---------- - index : - - Returns - ------- - applies : bool - - """ - - def applies(interval): - return interval[0] <= index <= interval[1] - - if isinstance(self.interval, tuple): - return applies(self.interval) - return any([applies(i) for i in self.interval]) - - -@model_attribute_typed( - types={ - "only": OnlyConstraint, - "zero": ZeroConstraint, - }, - no_label=True, -) -class SpectralConstraint: - """A compartment constraint is applied on one compartment on one or many - intervals on the estimated axis type. - - There are three types: zero, equal and equal area. See the documentation of - the respective classes for details. - """ - - pass - - -def apply_spectral_constraints( - model: KineticSpectrumModel, clp_labels: list[str], matrix: np.ndarray, index: float -) -> tuple[list[str], np.ndarray]: - for constraint in model.spectral_constraints: - if isinstance(constraint, (OnlyConstraint, ZeroConstraint)) and constraint.applies(index): - idx = [not label == constraint.compartment for label in clp_labels] - clp_labels = [label for label in clp_labels if label != constraint.compartment] - matrix = matrix[:, idx] - return (clp_labels, matrix) diff --git a/glotaran/builtin/models/kinetic_spectrum/spectral_constraints.pyi b/glotaran/builtin/models/kinetic_spectrum/spectral_constraints.pyi deleted file mode 100644 index ba9cd7a15..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/spectral_constraints.pyi +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import numpy as np - -from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_model import KineticSpectrumModel -from glotaran.model import model_attribute -from glotaran.model import model_attribute_typed - -class OnlyConstraint: - @property - def compartment(self) -> str: ... - @property - def interval(self) -> list[tuple[float, float]]: ... - def applies(self, index: Any) -> bool: ... - -class ZeroConstraint: - @property - def compartment(self) -> str: ... - @property - def interval(self) -> list[tuple[float, float]]: ... 
- def applies(self, index: Any) -> bool: ... - -class SpectralConstraint: ... # noqa: E701 - -def apply_spectral_constraints( - model: KineticSpectrumModel, clp_labels: list[str], matrix: np.ndarray, index: float -) -> tuple[list[str], np.ndarray]: ... diff --git a/glotaran/builtin/models/kinetic_spectrum/spectral_irf.py b/glotaran/builtin/models/kinetic_spectrum/spectral_irf.py deleted file mode 100644 index 1e4578ceb..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/spectral_irf.py +++ /dev/null @@ -1,96 +0,0 @@ -import typing - -import numpy as np - -from glotaran.builtin.models.kinetic_image.irf import Irf -from glotaran.builtin.models.kinetic_image.irf import IrfMultiGaussian -from glotaran.model import model_attribute -from glotaran.parameter import Parameter - - -@model_attribute( - properties={ - "dispersion_center": {"type": Parameter, "allow_none": True}, - "center_dispersion": {"type": typing.List[Parameter], "default": []}, - "width_dispersion": {"type": typing.List[Parameter], "default": []}, - "model_dispersion_with_wavenumber": {"type": bool, "default": False}, - }, - has_type=True, -) -class IrfSpectralMultiGaussian(IrfMultiGaussian): - """ - Represents a gaussian IRF. - - One width and one center is a single gauss. - - One center and multiple widths is a multiple gaussian. - - Multiple center and multiple widths is Double-, Triple- , etc. Gaussian. - - Parameters - ---------- - - label: - label of the irf - center: - one or more center of the irf as parameter indices - width: - one or more widths of the gaussian as parameter index - center_dispersion: - polynomial coefficients for the dispersion of the - center as list of parameter indices. None for no dispersion. - width_dispersion: - polynomial coefficients for the dispersion of the - width as parameter indices. None for no dispersion. 
- - """ - - def parameter(self, global_index: int, global_axis: np.ndarray): - centers, widths, scale, shift, backsweep, backsweep_period = super().parameter( - global_index, global_axis - ) - - index = global_axis[global_index] if global_index is not None else None - - if self.dispersion_center is not None: - dist = ( - (1e3 / index - 1e3 / self.dispersion_center) - if self.model_dispersion_with_wavenumber - else (index - self.dispersion_center) / 100 - ) - - if len(self.center_dispersion) != 0: - if self.dispersion_center is None: - raise Exception(self, f'No dispersion center defined for irf "{self.label}"') - for i, disp in enumerate(self.center_dispersion): - centers += disp * np.power(dist, i + 1) - - if len(self.width_dispersion) != 0: - if self.dispersion_center is None: - raise Exception(self, f'No dispersion center defined for irf "{self.label}"') - for i, disp in enumerate(self.width_dispersion): - widths = widths + disp * np.power(dist, i + 1) - - return centers, widths, scale, shift, backsweep, backsweep_period - - def calculate_dispersion(self, axis): - dispersion = [] - for index, _ in enumerate(axis): - center, _, _, _, _, _ = self.parameter(index, axis) - dispersion.append(center) - return np.asarray(dispersion).T - - -@model_attribute( - properties={ - "center": Parameter, - "width": Parameter, - }, - has_type=True, -) -class IrfSpectralGaussian(IrfSpectralMultiGaussian): - pass - - -Irf.add_type("spectral-multi-gaussian", IrfSpectralMultiGaussian) -Irf.add_type("spectral-gaussian", IrfSpectralGaussian) diff --git a/glotaran/builtin/models/kinetic_spectrum/spectral_irf.pyi b/glotaran/builtin/models/kinetic_spectrum/spectral_irf.pyi deleted file mode 100644 index 48b70879f..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/spectral_irf.pyi +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -from typing import Any - -from glotaran.builtin.models.kinetic_image.irf import IrfMultiGaussian -from glotaran.parameter import Parameter - -class IrfSpectralMultiGaussian(IrfMultiGaussian): - @property - def dispersion_center(self) -> Parameter: ... - @property - def center_dispersion(self) -> list[Parameter]: ... - @property - def width_dispersion(self) -> list[Parameter]: ... - @property - def model_dispersion_with_wavenumber(self) -> bool: ... - def parameter(self, index: Any): ... - def calculate_dispersion(self, axis: Any): ... - -class IrfSpectralGaussian(IrfSpectralMultiGaussian): - @property - def center(self) -> Parameter: ... - @property - def width(self) -> Parameter: ... - -class IrfGaussianCoherentArtifact(IrfSpectralGaussian): - @property - def coherent_artifact_order(self) -> int: ... - @property - def coherent_artifact_width(self) -> Parameter: ... - def clp_labels(self): ... - def calculate_coherent_artifact(self, axis: Any): ... diff --git a/glotaran/builtin/models/kinetic_spectrum/spectral_matrix.py b/glotaran/builtin/models/kinetic_spectrum/spectral_matrix.py deleted file mode 100644 index 5b0dce3cf..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/spectral_matrix.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Glotaran Spectral Matrix""" - -import numpy as np - - -def spectral_matrix(dataset, axis): - """Calculates the matrix. - - Parameters - ---------- - matrix : np.array - The preallocated matrix. - - compartment_order : list(str) - A list of compartment labels to map compartments to indices in the - matrix. 
- - parameter : glotaran.model.ParameterGroup - - """ - if dataset.initial_concentration is None: - return None, None - shape_compartments = list(dataset.shape) - compartments = [ - c for c in dataset.initial_concentration.compartments if c in shape_compartments - ] - matrix = np.zeros((axis.size, len(compartments))) - for i, comp in enumerate(compartments): - shapes = dataset.shape[comp] - if not isinstance(shapes, list): - shapes = [shapes] - for shape in shapes: - matrix[:, i] += shape.calculate(axis) - return compartments, matrix diff --git a/glotaran/builtin/models/kinetic_spectrum/spectral_penalties.py b/glotaran/builtin/models/kinetic_spectrum/spectral_penalties.py deleted file mode 100644 index 6956a7b6e..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/spectral_penalties.py +++ /dev/null @@ -1,161 +0,0 @@ -"""This package contains compartment constraint items.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING -from typing import List -from typing import Tuple - -import numpy as np -import xarray as xr - -from glotaran.model import model_attribute -from glotaran.parameter import Parameter - -if TYPE_CHECKING: - from typing import Any - from typing import Sequence - - from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_model import ( - KineticSpectrumModel, - ) - from glotaran.parameter import ParameterGroup - - -@model_attribute( - properties={ - "source": str, - "source_intervals": List[Tuple[float, float]], - "target": str, - "target_intervals": List[Tuple[float, float]], - "parameter": Parameter, - "weight": str, - }, - no_label=True, -) -class EqualAreaPenalty: - """An equal area constraint adds a the differenc of the sum of a - compartments in the e matrix in one ore more intervals to the scaled sum - of the e matrix of one or more target compartments to residual. The additional - residual is scaled with the weight.""" - - def applies(self, index: Any) -> bool: - """ - Returns true if the index is in one of the intervals. 
- - Parameters - ---------- - index : - - Returns - ------- - applies : bool - - """ - - def applies(interval): - return interval[0] <= index <= interval[1] - - if isinstance(self.interval, tuple): - return applies(self.interval) - return any([applies(i) for i in self.interval]) - - -def has_spectral_penalties(model: KineticSpectrumModel) -> bool: - return len(model.equal_area_penalties) != 0 - - -def apply_spectral_penalties( - model: KineticSpectrumModel, - parameters: ParameterGroup, - clp_labels: dict[str, list[str] | list[list[str]]], - clps: dict[str, list[np.ndarray]], - matrices: dict[str, np.ndarray | list[np.ndarray]], - data: dict[str, xr.Dataset], - group_tolerance: float, -) -> np.ndarray: - - penalties = [] - for penalty in model.equal_area_penalties: - - penalty = penalty.fill(model, parameters) - source_area = _get_area( - model.index_dependent(), - model.global_dimension, - clp_labels, - clps, - data, - group_tolerance, - penalty.source_intervals, - penalty.source, - ) - - target_area = _get_area( - model.index_dependent(), - model.global_dimension, - clp_labels, - clps, - data, - group_tolerance, - penalty.target_intervals, - penalty.target, - ) - - area_penalty = np.abs(np.sum(source_area) - penalty.parameter * np.sum(target_area)) - penalties.append(area_penalty * penalty.weight) - return np.asarray(penalties) - - -def _get_area( - index_dependent: bool, - global_dimension: str, - clp_labels: dict[str, list[list[str]]], - clps: dict[str, list[np.ndarray]], - data: dict[str, xr.Dataset], - group_tolerance: float, - intervals: list[tuple[float, float]], - compartment: str, -) -> np.ndarray: - area = [] - area_indices = [] - - for label, dataset in data.items(): - global_axis = dataset.coords[global_dimension] - for interval in intervals: - if interval[0] > global_axis[-1]: - # interval not in this dataset - continue - - start_idx, end_idx = _get_idx_from_interval(interval, global_axis) - for i in range(start_idx, end_idx + 1): - index_clp_labels = clp_labels[label][i] if index_dependent else clp_labels[label] - if compartment in index_clp_labels: - area.append(clps[label][i][index_clp_labels.index(compartment)]) - area_indices.append(global_axis[i]) - - return np.asarray(area) # TODO: normalize for distance on global axis - - -def _get_idx_from_interval( - interval: tuple[float, float], axis: Sequence[float] | np.ndarray -) -> tuple[int, int]: - """Retrieves start and end index of an interval on some axis - - Parameters - ---------- - interval : A tuple of floats with begin and end of the interval - axis : Array like object which can be cast to np.array - - Returns - ------- - start, end : tuple of int - - """ - axis_array = np.array(axis) - start = np.abs(axis_array - interval[0]).argmin() if not np.isinf(interval[0]) else 0 - end = ( - np.abs(axis_array - interval[1]).argmin() - if not np.isinf(interval[1]) - else axis_array.size - 1 - ) - return start, end diff --git a/glotaran/builtin/models/kinetic_spectrum/spectral_penalties.pyi b/glotaran/builtin/models/kinetic_spectrum/spectral_penalties.pyi deleted file mode 100644 index eb3d7c44c..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/spectral_penalties.pyi +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import numpy as np - -from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_model import KineticSpectrumModel -from glotaran.model import model_attribute -from glotaran.parameter import Parameter -from glotaran.parameter import ParameterGroup - -class 
EqualAreaPenalty: - @property - def compartment(self) -> str: ... - @property - def interval(self) -> list[tuple[float, float]]: ... - @property - def target(self) -> str: ... - @property - def parameter(self) -> Parameter: ... - @property - def weight(self) -> str: ... - def applies(self, index: Any) -> bool: ... - -def has_spectral_penalties(model: KineticSpectrumModel) -> bool: ... -def apply_spectral_penalties( - model: KineticSpectrumModel, - parameters: ParameterGroup, - clp_labels: list[str] | list[list[str]], - clps: np.ndarray, - global_axis: np.ndarray, -) -> np.ndarray: ... diff --git a/glotaran/builtin/models/kinetic_spectrum/spectral_relations.py b/glotaran/builtin/models/kinetic_spectrum/spectral_relations.py deleted file mode 100644 index e3aa372d0..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/spectral_relations.py +++ /dev/null @@ -1,125 +0,0 @@ -""" Glotaran Spectral Relation """ -from __future__ import annotations - -import warnings -from typing import TYPE_CHECKING -from typing import List -from typing import Tuple - -import numpy as np -import xarray as xr - -from glotaran.model import model_attribute -from glotaran.parameter import Parameter - -if TYPE_CHECKING: - from typing import Any - - from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_model import ( - KineticSpectrumModel, - ) - from glotaran.parameter import ParameterGroup - - -@model_attribute( - properties={ - "compartment": str, - "target": str, - "parameter": Parameter, - "interval": List[Tuple[float, float]], - }, - no_label=True, -) -class SpectralRelation: - def applies(self, index: Any) -> bool: - """ - Returns true if the index is in one of the intervals. - - Parameters - ---------- - index : - - Returns - ------- - applies : bool - - """ - return any(interval[0] <= index <= interval[1] for interval in self.interval) - - -def create_spectral_relation_matrix( - model: KineticSpectrumModel, - dataset: str, - parameters: ParameterGroup, - clp_labels: list[str], - matrix: np.ndarray, - index: float, -) -> tuple[list[str], np.ndarray]: - relation_matrix = np.diagflat([1.0 for _ in clp_labels]) - - idx_to_delete = [] - for relation in model.spectral_relations: - if relation.target in clp_labels and relation.applies(index): - - if relation.compartment not in clp_labels: - warnings.warn( - "Relation between compartments '{relation.compartment}' and " - f"'{relation.target}' cannot be applied for '{dataset}'. 
" - f" '{relation.source}' not in dataset" - ) - continue - - relation = relation.fill(model, parameters) - source_idx = clp_labels.index(relation.compartment) - target_idx = clp_labels.index(relation.target) - relation_matrix[target_idx, source_idx] = relation.parameter - idx_to_delete.append(target_idx) - - clp_labels = [label for i, label in enumerate(clp_labels) if i not in idx_to_delete] - relation_matrix = np.delete(relation_matrix, idx_to_delete, axis=1) - return (clp_labels, relation_matrix) - - -def apply_spectral_relations( - model: KineticSpectrumModel, - dataset: str, - parameters: ParameterGroup, - clp_labels: list[str], - matrix: np.ndarray, - index: float, -) -> tuple[list[str], np.ndarray]: - - if not model.spectral_relations: - return (clp_labels, matrix) - - reduced_clp_labels, relation_matrix = create_spectral_relation_matrix( - model, dataset, parameters, clp_labels, matrix, index - ) - - reduced_matrix = matrix @ relation_matrix - - return (reduced_clp_labels, reduced_matrix) - - -def retrieve_related_clps( - model: KineticSpectrumModel, - parameters: ParameterGroup, - clp_labels: dict[str, list[str] | list[list[str]]], - clps: dict[str, list[np.ndarray]], - data: dict[str, xr.Dataset], -) -> dict[str, list[np.ndarray]]: - - for relation in model.spectral_relations: - relation = relation.fill(model, parameters) - for label, dataset_clp_labels in clp_labels.items(): - for i, index in enumerate(data[label].coords[model.global_dimension]): - if ( - relation.target in dataset_clp_labels[i] - and relation.compartment in dataset_clp_labels[i] - and relation.applies(index) - ): - target_idx = dataset_clp_labels[i].index(relation.target) - source_idx = dataset_clp_labels[i].index(relation.compartment) - clps[label][i][target_idx] = clps[label][i][source_idx] * relation.parameter - - return clps diff --git a/glotaran/builtin/models/kinetic_spectrum/spectral_relations.pyi b/glotaran/builtin/models/kinetic_spectrum/spectral_relations.pyi deleted file mode 100644 index 3c8a62da0..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/spectral_relations.pyi +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import numpy as np - -from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_model import KineticSpectrumModel -from glotaran.model import model_attribute -from glotaran.parameter import Parameter -from glotaran.parameter import ParameterGroup - -class SpectralRelation: - @property - def compartment(self) -> str: ... - @property - def target(self) -> str: ... - @property - def parameter(self) -> Parameter: ... - @property - def interval(self) -> list[tuple[float, float]]: ... - def applies(self, index: Any) -> bool: ... - -def create_spectral_relation_matrix( - model: KineticSpectrumModel, - parameters: ParameterGroup, - clp_labels: list[str], - matrix: np.ndarray, - index: float, -) -> tuple[list[str], np.ndarray]: ... -def apply_spectral_relations( - model: KineticSpectrumModel, - parameters: ParameterGroup, - clp_labels: list[str], - matrix: np.ndarray, - index: float, -) -> tuple[list[str], np.ndarray]: ... -def retrieve_related_clps( - model: KineticSpectrumModel, - parameters: ParameterGroup, - clp_labels: list[str], - clps: np.ndarray, - index: float, -) -> tuple[list[str], np.ndarray]: ... 
diff --git a/glotaran/builtin/models/kinetic_spectrum/spectral_shape.py b/glotaran/builtin/models/kinetic_spectrum/spectral_shape.py deleted file mode 100644 index 8bd9a8b3d..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/spectral_shape.py +++ /dev/null @@ -1,89 +0,0 @@ -"""This package contains the spectral shape item.""" - -import numpy as np - -from glotaran.model import model_attribute -from glotaran.model import model_attribute_typed -from glotaran.parameter import Parameter - - -@model_attribute( - properties={ - "amplitude": Parameter, - "location": Parameter, - "width": Parameter, - }, - has_type=True, -) -class SpectralShapeGaussian: - """A gaussian spectral shape""" - - def calculate(self, axis: np.ndarray) -> np.ndarray: - """calculate calculates the shape. - - Parameters - ---------- - axis: np.ndarray - The axis to calculate the shape on. - - Returns - ------- - shape: numpy.ndarray - - """ - return self.amplitude * np.exp( - -np.log(2) * np.square(2 * (axis - self.location) / self.width) - ) - - -@model_attribute(properties={}, has_type=True) -class SpectralShapeOne: - """A gaussian spectral shape""" - - def calculate(self, axis: np.ndarray) -> np.ndarray: - """calculate calculates the shape. - - Parameters - ---------- - axis: np.ndarray - The axies to calculate the shape on. - - Returns - ------- - shape: numpy.ndarray - - """ - return np.ones(axis.shape[0]) - - -@model_attribute(properties={}, has_type=True) -class SpectralShapeZero: - """A gaussian spectral shape""" - - def calculate(self, axis: np.ndarray) -> np.ndarray: - """calculate calculates the shape. - - Only works after calling calling 'fill'. - - Parameters - ---------- - axis: np.ndarray - The axies to calculate the shape on. - - Returns - ------- - shape: numpy.ndarray - - """ - return np.zeros(axis.shape[0]) - - -@model_attribute_typed( - types={ - "gaussian": SpectralShapeGaussian, - "one": SpectralShapeOne, - "zero": SpectralShapeZero, - } -) -class SpectralShape: - """Base class for spectral shapes""" diff --git a/glotaran/builtin/models/kinetic_spectrum/spectral_shape.pyi b/glotaran/builtin/models/kinetic_spectrum/spectral_shape.pyi deleted file mode 100644 index 18aa5c43b..000000000 --- a/glotaran/builtin/models/kinetic_spectrum/spectral_shape.pyi +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations - -import numpy as np - -from glotaran.model import model_attribute -from glotaran.model import model_attribute_typed -from glotaran.parameter import Parameter - -class SpectralShapeGaussian: - @property - def amplitude(self) -> Parameter: ... - @property - def location(self) -> Parameter: ... - @property - def width(self) -> Parameter: ... - def calculate(self, axis: np.ndarray) -> np.ndarray: ... - -class SpectralShapeOne: - def calculate(self, axis: np.ndarray) -> np.ndarray: ... - -class SpectralShapeZero: - def calculate(self, axis: np.ndarray) -> np.ndarray: ... - -class SpectralShape: - @classmethod - def add_type(cls, type_name: str, attribute_type: type) -> None: ... 
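The Gaussian shape removed above parameterises a spectral band by amplitude, location and a width that acts as a full width at half maximum (the exponent reaches -ln(2) at location ± width/2). A small numerical check of that formula, with arbitrary values:

import numpy as np

def gaussian_shape(axis: np.ndarray, amplitude: float, location: float, width: float) -> np.ndarray:
    """amplitude * exp(-ln(2) * (2 * (axis - location) / width)**2), as in the deleted class."""
    return amplitude * np.exp(-np.log(2) * np.square(2 * (axis - location) / width))

spectral_axis = np.arange(600, 700, 1.4)
band = gaussian_shape(spectral_axis, amplitude=1.0, location=650.0, width=20.0)
print(band.max(), spectral_axis[band.argmax()])  # peaks near the location, close to the given amplitude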
diff --git a/glotaran/builtin/models/spectral/__init__.py b/glotaran/builtin/models/spectral/__init__.py deleted file mode 100644 index 8f658b1b5..000000000 --- a/glotaran/builtin/models/spectral/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from glotaran.builtin.models.spectral.spectral_model import SpectralModel diff --git a/glotaran/builtin/models/spectral/spectral_megacomplex.py b/glotaran/builtin/models/spectral/spectral_megacomplex.py deleted file mode 100644 index 7e3603552..000000000 --- a/glotaran/builtin/models/spectral/spectral_megacomplex.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import annotations - -from typing import Dict - -import numpy as np -import xarray as xr - -from glotaran.model import DatasetDescriptor -from glotaran.model import Megacomplex -from glotaran.model import ModelError -from glotaran.model import megacomplex - - -@megacomplex( - "spectral", - properties={ - "shape": Dict[str, str], - }, -) -class SpectralMegacomplex(Megacomplex): - def calculate_matrix( - self, - dataset_model: DatasetDescriptor, - indices: dict[str, int], - **kwargs, - ): - - compartments = [] - for compartment in self.shape: - if compartment in compartments: - raise ModelError(f"More then one shape defined for compartment '{compartment}'") - compartments.append(compartment) - - model_dimension = dataset_model.get_model_dimension() - model_axis = dataset_model.get_coords()[model_dimension] - - dim1 = model_axis.size - dim2 = len(self.shape) - matrix = np.zeros((dim1, dim2)) - - for i, shape in enumerate(self.shape.values()): - matrix[:, i] += shape.calculate(model_axis.values) - return xr.DataArray( - matrix, coords=((model_dimension, model_axis.data), ("clp_label", compartments)) - ) - - def index_dependent(self, dataset: DatasetDescriptor) -> bool: - return False diff --git a/glotaran/builtin/models/spectral/spectral_model.py b/glotaran/builtin/models/spectral/spectral_model.py deleted file mode 100644 index 018c37080..000000000 --- a/glotaran/builtin/models/spectral/spectral_model.py +++ /dev/null @@ -1,28 +0,0 @@ -from glotaran.builtin.models.kinetic_image.kinetic_image_dataset_descriptor import ( - KineticImageDatasetDescriptor, -) -from glotaran.builtin.models.spectral.shape import SpectralShape -from glotaran.builtin.models.spectral.spectral_megacomplex import SpectralMegacomplex -from glotaran.builtin.models.spectral.spectral_result import finalize_spectral_result -from glotaran.model import Model -from glotaran.model import model - - -@model( - "spectral-model", - attributes={ - "shape": SpectralShape, - }, - dataset_type=KineticImageDatasetDescriptor, - default_megacomplex_type="spectral", - megacomplex_types={ - "spectral": SpectralMegacomplex, - }, - model_dimension="spectral", - global_dimension="time", - grouped=False, - index_dependent=False, - finalize_data_function=finalize_spectral_result, -) -class SpectralModel(Model): - pass diff --git a/glotaran/builtin/models/spectral/spectral_result.py b/glotaran/builtin/models/spectral/spectral_result.py deleted file mode 100644 index d9b932def..000000000 --- a/glotaran/builtin/models/spectral/spectral_result.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -import xarray as xr - -from glotaran.analysis.problem import Problem -from glotaran.builtin.models.spectral.spectral_megacomplex import SpectralMegacomplex - - -def finalize_spectral_result(model, problem: Problem, data: dict[str, xr.Dataset]): - - for label, dataset in data.items(): - - dataset_model = problem.filled_dataset_descriptors[label] - - 
retrieve_spectral_data(problem.model, dataset, dataset_model) - - -def retrieve_spectral_data(model, dataset, dataset_model): - spectral_compartments = [] - for megacomplex in dataset_model.megacomplex: - if isinstance(megacomplex, SpectralMegacomplex): - spectral_compartments += [ - compartment - for compartment in megacomplex.shape - if compartment not in spectral_compartments - ] - - dataset.coords["species"] = spectral_compartments - dataset["species_spectra"] = ( - ( - dataset_model.get_model_dimension(), - "species", - ), - dataset.matrix.sel(clp_label=spectral_compartments).values, - ) - dataset["species_associated_concentrations"] = ( - ( - dataset_model.get_global_dimension(), - "species", - ), - dataset.clp.sel(clp_label=spectral_compartments).data, - ) diff --git a/glotaran/deprecation/modules/test/test_glotaran_root.py b/glotaran/deprecation/modules/test/test_glotaran_root.py index 59d7845e1..bb02eab34 100644 --- a/glotaran/deprecation/modules/test/test_glotaran_root.py +++ b/glotaran/deprecation/modules/test/test_glotaran_root.py @@ -56,24 +56,30 @@ def test_deprecation_warning_on_call_test_helper_no_warn(): def test_read_model_from_yaml(): """read_model_from_yaml raises warning""" + yaml = """ + type: kinetic-spectrum + megacomplex: {} + """ result = deprecation_warning_on_call_test_helper( - read_model_from_yaml, args=["type: kinetic-spectrum"], raise_exception=True + read_model_from_yaml, args=[yaml], raise_exception=True ) assert isinstance(result, Model) - assert result.model_type == "kinetic-spectrum" def test_read_model_from_yaml_file(tmp_path: Path): """read_model_from_yaml_file raises warning""" + yaml = """ + type: kinetic-spectrum + megacomplex: {} + """ model_file = tmp_path / "model.yaml" - model_file.write_text("type: kinetic-spectrum") + model_file.write_text(yaml) result = deprecation_warning_on_call_test_helper( read_model_from_yaml_file, args=[str(model_file)], raise_exception=True ) assert isinstance(result, Model) - assert result.model_type == "kinetic-spectrum" def test_read_parameters_from_csv_file(tmp_path: Path): diff --git a/glotaran/deprecation/modules/test/test_model_base_model.py b/glotaran/deprecation/modules/test/test_model_base_model.py deleted file mode 100644 index c91413e67..000000000 --- a/glotaran/deprecation/modules/test/test_model_base_model.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Test deprecated functionality in 'glotaran.model.base_model'.""" -from __future__ import annotations - -import pytest - -from glotaran.deprecation.modules.test import deprecation_warning_on_call_test_helper -from glotaran.model.base_model import Model -from glotaran.model.test.test_model import MockModel - - -def test_model_simulate_method(): - """Model.simulate raises deperecation warning""" - deprecation_warning_on_call_test_helper(Model().simulate) - - -def test_model_index_dependent_method(): - """Model.index_dependent raises deperecation warning""" - deprecation_warning_on_call_test_helper(MockModel().index_dependent) - - -def test_model_global_dimension_property(): - """Model.global_dimension raises deperecation warning""" - with pytest.warns(DeprecationWarning): - MockModel().global_dimension - - -def test_model_model_dimension_property(): - """Model.model_dimension raises deperecation warning""" - with pytest.warns(DeprecationWarning): - MockModel().model_dimension diff --git a/glotaran/deprecation/modules/test/test_model_dataset_deescriptor.py b/glotaran/deprecation/modules/test/test_model_dataset_deescriptor.py deleted file mode 100644 index 
2136f9dfc..000000000 --- a/glotaran/deprecation/modules/test/test_model_dataset_deescriptor.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Test deprecated functionality in 'glotaran.model.base_model'.""" -from __future__ import annotations - -from glotaran.deprecation.modules.test import deprecation_warning_on_call_test_helper -from glotaran.model.dataset_descriptor import DatasetDescriptor - - -def test_dataset_descriptor_overwrite_index_dependent_method(): - """DatasetDescriptor.overwrite_index_dependent raises deperecation warning""" - deprecation_warning_on_call_test_helper(DatasetDescriptor().overwrite_index_dependent) diff --git a/glotaran/deprecation/modules/test/test_project_sheme.py b/glotaran/deprecation/modules/test/test_project_sheme.py index b9ba91d6a..93ba18793 100644 --- a/glotaran/deprecation/modules/test/test_project_sheme.py +++ b/glotaran/deprecation/modules/test/test_project_sheme.py @@ -16,8 +16,17 @@ def test_Scheme_from_yaml_file_method(tmp_path: Path): """Create Scheme from file.""" scheme_path = tmp_path / "scheme.yml" + model_yml_str = """ + megacomplex: + m1: + type: decay + k_matrix: [] + dataset: + dataset1: + megacomplex: [m1] + """ model_path = tmp_path / "model.yml" - model_path.write_text("type: kinetic-spectrum\ndataset:\n dataset1:\n megacomplex: []") + model_path.write_text(model_yml_str) parameter_path = tmp_path / "parameters.yml" parameter_path.write_text("[1.0, 67.0]") diff --git a/glotaran/examples/sequential.py b/glotaran/examples/sequential.py index c64068fcf..c50d5cf9c 100644 --- a/glotaran/examples/sequential.py +++ b/glotaran/examples/sequential.py @@ -1,10 +1,13 @@ import numpy as np +import xarray as xr from glotaran.analysis.simulation import simulate -from glotaran.builtin.models.kinetic_spectrum import KineticSpectrumModel +from glotaran.builtin.megacomplexes.decay import DecayMegacomplex +from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex +from glotaran.model import Model from glotaran.parameter import ParameterGroup -sim_model = KineticSpectrumModel.from_dict( +sim_model = Model.from_dict( { "initial_concentration": { "j1": { @@ -23,24 +26,33 @@ }, "megacomplex": { "m1": { + "type": "decay", "k_matrix": ["k1"], - } + }, + "m2": { + "type": "spectral", + "shape": { + "s1": "sh1", + "s2": "sh2", + "s3": "sh3", + }, + }, }, "shape": { "sh1": { - "type": "gaussian", + "type": "skewed-gaussian", "amplitude": "shapes.amps.1", "location": "shapes.locs.1", "width": "shapes.width.1", }, "sh2": { - "type": "gaussian", + "type": "skewed-gaussian", "amplitude": "shapes.amps.2", "location": "shapes.locs.2", "width": "shapes.width.2", }, "sh3": { - "type": "gaussian", + "type": "skewed-gaussian", "amplitude": "shapes.amps.3", "location": "shapes.locs.3", "width": "shapes.width.3", @@ -53,15 +65,12 @@ "dataset1": { "initial_concentration": "j1", "megacomplex": ["m1"], - "shape": { - "s1": "sh1", - "s2": "sh2", - "s3": "sh3", - }, + "global_megacomplex": ["m2"], "irf": "irf1", } }, - } + }, + megacomplex_types={"decay": DecayMegacomplex, "spectral": SpectralMegacomplex}, ) wanted_parameter = ParameterGroup.from_dict( @@ -95,8 +104,8 @@ } ) -_time = np.arange(-1, 20, 0.01) -_spectral = np.arange(600, 700, 1.4) +_time = xr.DataArray(np.arange(-1, 20, 0.01)) +_spectral = xr.DataArray(np.arange(600, 700, 1.4)) dataset = simulate( sim_model, @@ -107,7 +116,7 @@ noise_std_dev=1e-2, ) -model = KineticSpectrumModel.from_dict( +model = Model.from_dict( { "initial_concentration": { "j1": {"compartments": ["s1", "s2", "s3"], "parameters": ["j.1", 
"j.0", "j.0"]}, @@ -123,6 +132,7 @@ }, "megacomplex": { "m1": { + "type": "decay", "k_matrix": ["k1"], } }, @@ -136,5 +146,6 @@ "irf": "irf1", } }, - } + }, + megacomplex_types={"decay": DecayMegacomplex}, ) diff --git a/glotaran/model/__init__.py b/glotaran/model/__init__.py index 00379885e..fc07bcdf3 100644 --- a/glotaran/model/__init__.py +++ b/glotaran/model/__init__.py @@ -4,22 +4,21 @@ common model items. """ -from glotaran.model.attribute import model_attribute -from glotaran.model.attribute import model_attribute_typed -from glotaran.model.base_model import Model from glotaran.model.clp_penalties import EqualAreaPenalty from glotaran.model.constraint import Constraint from glotaran.model.constraint import OnlyConstraint from glotaran.model.constraint import ZeroConstraint -from glotaran.model.dataset_descriptor import DatasetDescriptor -from glotaran.model.decorator import model +from glotaran.model.dataset_model import DatasetModel +from glotaran.model.item import model_item +from glotaran.model.item import model_item_typed from glotaran.model.megacomplex import Megacomplex from glotaran.model.megacomplex import megacomplex +from glotaran.model.model import Model from glotaran.model.relation import Relation from glotaran.model.util import ModelError from glotaran.model.weight import Weight -from glotaran.plugin_system.model_registration import get_model -from glotaran.plugin_system.model_registration import is_known_model -from glotaran.plugin_system.model_registration import known_model_names -from glotaran.plugin_system.model_registration import model_plugin_table -from glotaran.plugin_system.model_registration import set_model_plugin +from glotaran.plugin_system.megacomplex_registration import get_megacomplex +from glotaran.plugin_system.megacomplex_registration import is_known_megacomplex +from glotaran.plugin_system.megacomplex_registration import known_megacomplex_names +from glotaran.plugin_system.megacomplex_registration import megacomplex_plugin_table +from glotaran.plugin_system.megacomplex_registration import set_megacomplex_plugin diff --git a/glotaran/model/base_model.py b/glotaran/model/base_model.py deleted file mode 100644 index c33563a42..000000000 --- a/glotaran/model/base_model.py +++ /dev/null @@ -1,283 +0,0 @@ -"""A base class for global analysis models.""" -from __future__ import annotations - -import copy -from typing import TYPE_CHECKING - -from glotaran.deprecation import deprecate -from glotaran.parameter import ParameterGroup -from glotaran.utils.ipython import MarkdownStr - -if TYPE_CHECKING: - import numpy as np - import xarray as xr - - -class Model: - """A base class for global analysis models.""" - - @classmethod - def from_dict(cls, model_dict_ref: dict) -> Model: - """Creates a model from a dictionary. - - Parameters - ---------- - model_dict : - Dictionary containing the model. - """ - - model = cls() - - model_dict = copy.deepcopy(model_dict_ref) - - # iterate over items - for name, attribute in list(model_dict.items()): - - # we determine if the item is known by the model by looking for - # a setter with same name. 
- - if hasattr(model, f"set_{name}"): - - # get the set function - model_set = getattr(model, f"set_{name}") - - for label, item in attribute.items(): - # we retrieve the actual class from the signature - item_cls = model_set.__func__.__annotations__["item"] - - is_typed = hasattr(item_cls, "_glotaran_model_attribute_typed") - - if isinstance(item, dict): - if is_typed: - if "type" not in item and item_cls.get_default_type() is None: - raise ValueError(f"Missing type for attribute '{name}'") - item_type = item.get("type", item_cls.get_default_type()) - - types = item_cls._glotaran_model_attribute_types - if item_type not in types: - raise ValueError( - f"Unknown type '{item_type}' for attribute '{name}'" - ) - item_cls = types[item_type] - item["label"] = label - model_set(label, item_cls.from_dict(item)) - elif isinstance(item, list): - if is_typed: - if len(item) < 2 and len(item) != 1: - raise ValueError(f"Missing type for attribute '{name}'") - item_type = item[0] - types = item_cls._glotaran_model_attribute_types - - if item_type not in types: - raise ValueError( - f"Unknown type '{item_type}' for attribute '{name}'" - ) - item_cls = types[item_type] - item = [label] + item - model_set(label, item_cls.from_list(item)) - del model_dict[name] - - elif hasattr(model, f"add_{name}"): - - # get the set function - add = getattr(model, f"add_{name}") - - # we retrieve the actual class from the signature - for item in attribute: - item_cls = add.__func__.__annotations__["item"] - is_typed = hasattr(item_cls, "_glotaran_model_attribute_typed") - if isinstance(item, dict): - if is_typed: - if "type" not in item: - raise ValueError(f"Missing type for attribute '{name}'") - item_type = item["type"] - - if item_type not in item_cls._glotaran_model_attribute_types: - raise ValueError( - f"Unknown type '{item_type}' for attribute '{name}'" - ) - item_cls = item_cls._glotaran_model_attribute_types[item_type] - add(item_cls.from_dict(item)) - elif isinstance(item, list): - if is_typed: - if len(item) < 2 and len(item) != 1: - raise ValueError(f"Missing type for attribute '{name}'") - item_type = ( - item[1] - if len(item) != 1 and hasattr(item_cls, "label") - else item[0] - ) - - if item_type not in item_cls._glotaran_model_attribute_types: - raise ValueError( - f"Unknown type '{item_type}' for attribute '{name}'" - ) - item_cls = item_cls._glotaran_model_attribute_types[item_type] - add(item_cls.from_list(item)) - del model_dict[name] - - return model - - @property - def model_type(self) -> str: - """The type of the model as human readable string.""" - return self._model_type - - def problem_list(self, parameters: ParameterGroup = None) -> list[str]: - """ - Returns a list with all problems in the model and missing parameters if specified. - - Parameters - ---------- - - parameter : - The parameter to validate. 
- """ - problems = [] - - attrs = getattr(self, "_glotaran_model_attributes") - for attr in attrs: - attr = getattr(self, attr) - if isinstance(attr, list): - for item in attr: - problems += item.validate(self, parameters=parameters) - else: - for _, item in attr.items(): - problems += item.validate(self, parameters=parameters) - - return problems - - @deprecate( - deprecated_qual_name_usage="glotaran.model.base_model.Model.simulate", - new_qual_name_usage="glotaran.analysis.simulation.simulate", - to_be_removed_in_version="0.6.0", - importable_indices=(2, 1), - ) - def simulate( - self, - dataset: str, - parameters: ParameterGroup, - axes: dict[str, np.ndarray] = None, - clp: np.ndarray | xr.DataArray = None, - noise: bool = False, - noise_std_dev: float = 1.0, - noise_seed: int = None, - ) -> xr.Dataset: - """Simulates the model. - - Parameters - ---------- - dataset : - Label of the dataset to simulate. - parameter : - The parameters for the simulation. - axes : - A dictionary with axes for simulation. - clp : - Conditionally linear parameters. Used instead of `model.global_matrix` if provided. - noise : - If `True` noise is added to the simulated data. - noise_std_dev : - The standard deviation of the noise. - noise_seed : - Seed for the noise. - """ - from glotaran.analysis.simulation import simulate - - return simulate( - self, - dataset, - parameters, - axes=axes, - clp=clp, - noise=noise, - noise_std_dev=noise_std_dev, - noise_seed=noise_seed, - ) - - def validate(self, parameters: ParameterGroup = None) -> str: - """ - Returns a string listing all problems in the model and missing parameters if specified. - - Parameters - ---------- - - parameter : - The parameter to validate. - """ - result = "" - - problems = self.problem_list(parameters) - if problems: - result = f"Your model has {len(problems)} problems:\n" - for p in problems: - result += f"\n * {p}" - else: - result = "Your model is valid." - return result - - def valid(self, parameters: ParameterGroup = None) -> bool: - """Returns `True` if the number problems in the model is 0, else `False` - - Parameters - ---------- - - parameter : - The parameter to validate. - """ - return len(self.problem_list(parameters)) == 0 - - def markdown( - self, - parameters: ParameterGroup = None, - initial_parameters: ParameterGroup = None, - base_heading_level: int = 1, - ) -> MarkdownStr: - """Formats the model as Markdown string. - - Parameters will be included if specified. - - Parameters - ---------- - parameter: ParameterGroup - Parameter to include. - initial_parameters: ParameterGroup - Initial values for the parameters. - base_heading_level: int - Base heading level of the markdown sections. - - E.g.: - - - If it is 1 the string will start with '# Model'. - - If it is 3 the string will start with '### Model'. 
- """ - base_heading = "#" * base_heading_level - attrs = getattr(self, "_glotaran_model_attributes") - string = f"{base_heading} Model\n\n" - string += f"_Type_: {self.model_type}\n\n" - - for attr in attrs: - child_attr = getattr(self, attr) - if not child_attr: - continue - - string += f"{base_heading}# {attr.replace('_', ' ').title()}\n\n" - - if isinstance(child_attr, dict): - child_attr = child_attr.values() - for item in child_attr: - item_str = item.mprint( - parameters=parameters, initial_parameters=initial_parameters - ).split("\n") - string += f"* {item_str[0]}\n" - for s in item_str[1:]: - string += f" {s}\n" - string += "\n" - return MarkdownStr(string) - - def _repr_markdown_(self) -> str: - """Special method used by ``ipython`` to render markdown.""" - return str(self.markdown(base_heading_level=3)) - - def __str__(self): - return str(self.markdown()) diff --git a/glotaran/model/base_model.pyi b/glotaran/model/base_model.pyi deleted file mode 100644 index 80a5e7fe3..000000000 --- a/glotaran/model/base_model.pyi +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import annotations - -from typing import Any -from typing import Callable -from typing import Mapping -from typing import TypeVar - -import numpy as np -import xarray as xr - -from glotaran.analysis.optimize import optimize # noqa: F401 -from glotaran.analysis.simulation import simulate # noqa: F401 -from glotaran.model.dataset_descriptor import DatasetDescriptor -from glotaran.model.decorator import FinalizeFunction -from glotaran.model.weight import Weight -from glotaran.parameter import ParameterGroup -from glotaran.project.result import Result -from glotaran.project.scheme import Scheme # noqa: F401 - -_Cls = TypeVar("_Cls") - -class Model: - _model_type: str - dataset: Mapping[str, DatasetDescriptor] - megacomplex: Any - weights: Weight - model_dimension: str - global_dimension: str - global_matrix = None - finalize_data: FinalizeFunction | None = ... - grouped: Callable[[], bool] - index_dependent: Callable[[], bool] - @staticmethod - def matrix( - dataset_descriptor: DatasetDescriptor = ..., axis=..., index=... - ) -> tuple[None, None] | tuple[list[Any], np.ndarray]: ... - def add_megacomplex(self, item: Any): ... - def add_weights(self, item: Weight): ... - def get_dataset(self, label: str) -> DatasetDescriptor: ... - @classmethod - def from_dict(cls: type[_Cls], model_dict_ref: dict) -> _Cls: ... - @property - def model_type(self) -> str: ... - def simulate( # noqa: F811 - self, - dataset: str, - parameters: ParameterGroup, - axes: dict[str, np.ndarray] = ..., - clp: np.ndarray | xr.DataArray = ..., - noise: bool = ..., - noise_std_dev: float = ..., - noise_seed: int = ..., - ) -> xr.Dataset: ... - def result_from_parameter( - self, - parameters: ParameterGroup, - data: dict[str, xr.DataArray | xr.Dataset], - nnls: bool = ..., - group_atol: float = ..., - ) -> Result: ... - def problem_list(self, parameters: ParameterGroup = ...) -> list[str]: ... - def validate(self, parameters: ParameterGroup = ...) -> str: ... - def valid(self, parameters: ParameterGroup = ...) -> bool: ... - def markdown( - self, - parameters: ParameterGroup = ..., - initial_parameters: ParameterGroup = ..., - base_heading_level: int = ..., - ) -> str: ... 
diff --git a/glotaran/model/clp_penalties.py b/glotaran/model/clp_penalties.py index 6956a7b6e..2906bfe25 100644 --- a/glotaran/model/clp_penalties.py +++ b/glotaran/model/clp_penalties.py @@ -9,7 +9,7 @@ import numpy as np import xarray as xr -from glotaran.model import model_attribute +from glotaran.model.item import model_item from glotaran.parameter import Parameter if TYPE_CHECKING: @@ -22,7 +22,7 @@ from glotaran.parameter import ParameterGroup -@model_attribute( +@model_item( properties={ "source": str, "source_intervals": List[Tuple[float, float]], @@ -31,7 +31,7 @@ "parameter": Parameter, "weight": str, }, - no_label=True, + has_label=False, ) class EqualAreaPenalty: """An equal area constraint adds a the differenc of the sum of a diff --git a/glotaran/model/constraint.py b/glotaran/model/constraint.py index d6683a85f..948f581ca 100644 --- a/glotaran/model/constraint.py +++ b/glotaran/model/constraint.py @@ -3,20 +3,20 @@ from typing import TYPE_CHECKING -from glotaran.model import model_attribute -from glotaran.model import model_attribute_typed from glotaran.model.interval_property import IntervalProperty +from glotaran.model.item import model_item +from glotaran.model.item import model_item_typed if TYPE_CHECKING: from typing import Any -@model_attribute( +@model_item( properties={ "target": str, }, has_type=True, - no_label=True, + has_label=False, ) class OnlyConstraint(IntervalProperty): """A only constraint sets the calculated matrix row of a compartment to 0 @@ -38,24 +38,24 @@ def applies(self, index: Any) -> bool: return not super().applies(index) -@model_attribute( +@model_item( properties={ "target": str, }, has_type=True, - no_label=True, + has_label=False, ) class ZeroConstraint(IntervalProperty): """A zero constraint sets the calculated matrix row of a compartment to 0 in the given intervals.""" -@model_attribute_typed( +@model_item_typed( types={ "only": OnlyConstraint, "zero": ZeroConstraint, }, - no_label=True, + has_label=False, ) class Constraint: """A constraint is applied on one clp on one or many diff --git a/glotaran/model/dataset_descriptor.py b/glotaran/model/dataset_descriptor.py deleted file mode 100644 index 93c086959..000000000 --- a/glotaran/model/dataset_descriptor.py +++ /dev/null @@ -1,102 +0,0 @@ -"""The DatasetDescriptor class.""" -from __future__ import annotations - -from typing import TYPE_CHECKING -from typing import Generator -from typing import List - -import xarray as xr - -from glotaran.deprecation import deprecate -from glotaran.model.attribute import model_attribute -from glotaran.parameter import Parameter - -if TYPE_CHECKING: - from glotaran.model.megacomplex import Megacomplex - - -@model_attribute( - properties={ - "megacomplex": List[str], - "megacomplex_scale": {"type": List[Parameter], "default": None, "allow_none": True}, - "scale": {"type": Parameter, "default": None, "allow_none": True}, - } -) -class DatasetDescriptor: - """A `DatasetDescriptor` describes a dataset in terms of a glotaran model. - It contains references to model items which describe the physical model for - a given dataset. - - A general dataset descriptor assigns one or more megacomplexes and a scale - parameter. 
- """ - - def iterate_megacomplexes(self) -> Generator[tuple[Parameter | int, Megacomplex | str]]: - for i, megacomplex in enumerate(self.megacomplex): - scale = self.megacomplex_scale[i] if self.megacomplex_scale is not None else None - yield scale, megacomplex - - def get_model_dimension(self) -> str: - if not hasattr(self, "_model_dimension"): - if len(self.megacomplex) == 0: - raise ValueError(f"No megacomplex set for dataset descriptor '{self.label}'") - if isinstance(self.megacomplex[0], str): - raise ValueError(f"Dataset descriptor '{self.label}' was not filled") - self._model_dimension = self.megacomplex[0].dimension - if any(self._model_dimension != m.dimension for m in self.megacomplex): - raise ValueError( - f"Megacomplex dimensions do not match for dataset descriptor '{self.label}'." - ) - return self._model_dimension - - def overwrite_model_dimension(self, model_dimension: str): - self._model_dimension = model_dimension - - # TODO: make explicit we only support 2 dimensions at present - # TODO: the global dimension should become a flexible index (MultiIndex) - # the user can then specify the name of the MultiIndex global dimension - # using the function overwrite_global_dimension - # e.g. in FLIM, x, y dimension may get 'flattened' to a MultiIndex 'pixel' - def get_global_dimension(self) -> str: - if not hasattr(self, "_global_dimension"): - if not hasattr(self, "_data"): - raise ValueError(f"Data not set for dataset descriptor '{self.label}'") - self._global_dimension = [ - dim for dim in self._data.data.dims if dim != self.get_model_dimension() - ][0] - return self._global_dimension - - def overwrite_global_dimension(self, global_dimension: str): - self._global_dimension = global_dimension - - def set_data(self, data: xr.Dataset) -> DatasetDescriptor: - self._data = data - return self - - def get_data(self) -> xr.Dataset: - return self._data - - def index_dependent(self) -> bool: - if hasattr(self, "_index_dependent"): - return self._index_dependent - return any(m.index_dependent(self) for m in self.megacomplex) - - def set_coords(self, coords: xr.Dataset): - self._coords = coords - - def get_coords(self) -> xr.Dataset: - if hasattr(self, "_coords"): - return self._coords - return self._data.coords - - @deprecate( - deprecated_qual_name_usage=( - "glotaran.model.dataset_descriptor.DatasetDescriptor.overwrite_index_dependent" - ), - new_qual_name_usage="", - to_be_removed_in_version="0.6.0", - importable_indices=(2, 2), - has_glotaran_replacement=False, - ) - def overwrite_index_dependent(self, index_dependent: bool): - self._index_dependent = index_dependent diff --git a/glotaran/model/dataset_descriptor.pyi b/glotaran/model/dataset_descriptor.pyi deleted file mode 100644 index 8014f0410..000000000 --- a/glotaran/model/dataset_descriptor.pyi +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations - -from typing import TypeVar - -from glotaran.model.attribute import model_attribute # noqa: F401 -from glotaran.model.base_model import Model -from glotaran.parameter import Parameter -from glotaran.parameter import ParameterGroup - -_T_Model = TypeVar("_T_Model", bound=Model) -_T_DatasetDescriptor = TypeVar("_T_DatasetDescriptor", bound=DatasetDescriptor) # noqa: F821 - -class DatasetDescriptor: - megacomplex: list[str] - megacomplex_scale: list[Parameter] | None = ... - scale: Parameter | None = ... - def fill(self, model: _T_Model, parameters: ParameterGroup) -> _T_DatasetDescriptor: ... 
- @classmethod - def from_dict(cls: type[_T_DatasetDescriptor], values: dict) -> _T_DatasetDescriptor: ... - @classmethod - def from_list(cls: type[_T_DatasetDescriptor], values: list) -> _T_DatasetDescriptor: ... - def validate(self, model: _T_Model, parameters=...) -> list[str]: ... - def mprint_item( - self, parameters: ParameterGroup = ..., initial_parameters: ParameterGroup = ... - ) -> str: ... diff --git a/glotaran/model/dataset_model.py b/glotaran/model/dataset_model.py new file mode 100644 index 000000000..2e407cc7e --- /dev/null +++ b/glotaran/model/dataset_model.py @@ -0,0 +1,162 @@ +"""The DatasetModel class.""" +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Generator + +import xarray as xr + +from glotaran.model.item import model_item +from glotaran.model.item import model_item_validator +from glotaran.parameter import Parameter + +if TYPE_CHECKING: + from glotaran.model.megacomplex import Megacomplex + from glotaran.model.model import Model + + +def create_dataset_model_type(properties: dict[str, any]) -> type: + @model_item(properties=properties) + class ModelDatasetModel(DatasetModel): + pass + + return ModelDatasetModel + + +class DatasetModel: + """A `DatasetModel` describes a dataset in terms of a glotaran model. + It contains references to model items which describe the physical model for + a given dataset. + + A general dataset descriptor assigns one or more megacomplexes and a scale + parameter. + """ + + def iterate_megacomplexes(self) -> Generator[tuple[Parameter | int, Megacomplex | str]]: + """Iterates of der dataset model's megacomplexes.""" + for i, megacomplex in enumerate(self.megacomplex): + scale = self.megacomplex_scale[i] if self.megacomplex_scale is not None else None + yield scale, megacomplex + + def iterate_global_megacomplexes(self) -> Generator[tuple[Parameter | int, Megacomplex | str]]: + """Iterates of der dataset model's global megacomplexes.""" + for i, megacomplex in enumerate(self.global_megacomplex): + scale = ( + self.global_megacomplex_scale[i] + if self.global_megacomplex_scale is not None + else None + ) + yield scale, megacomplex + + def get_model_dimension(self) -> str: + """Returns the dataset model's model dimension.""" + if not hasattr(self, "_model_dimension"): + if len(self.megacomplex) == 0: + raise ValueError(f"No megacomplex set for dataset descriptor '{self.label}'") + if isinstance(self.megacomplex[0], str): + raise ValueError(f"Dataset descriptor '{self.label}' was not filled") + self._model_dimension = self.megacomplex[0].dimension + if any(self._model_dimension != m.dimension for m in self.megacomplex): + raise ValueError( + f"Megacomplex dimensions do not match for dataset descriptor '{self.label}'." + ) + return self._model_dimension + + def finalize_data(self, data: xr.Dataset): + for megacomplex in self.megacomplex: + megacomplex.finalize_data(self, data) + + def overwrite_model_dimension(self, model_dimension: str): + """Overwrites the dataset model's model dimension.""" + self._model_dimension = model_dimension + + # TODO: make explicit we only support 2 dimensions at present + # TODO: the global dimension should become a flexible index (MultiIndex) + # the user can then specify the name of the MultiIndex global dimension + # using the function overwrite_global_dimension + # e.g. 
in FLIM, x, y dimension may get 'flattened' to a MultiIndex 'pixel' + def get_global_dimension(self) -> str: + """Returns the dataset model's global dimension.""" + if not hasattr(self, "_global_dimension"): + if self.global_model(): + if isinstance(self.global_megacomplex[0], str): + raise ValueError(f"Dataset descriptor '{self.label}' was not filled") + self._global_dimension = self.global_megacomplex[0].dimension + if any(self._global_dimension != m.dimension for m in self.global_megacomplex): + raise ValueError( + "Global megacomplex dimensions do not " + f"match for dataset model '{self.label}'." + ) + elif hasattr(self, "_coords"): + return next(dim for dim in self._coords if dim != self.get_model_dimension()) + else: + if not hasattr(self, "_data"): + raise ValueError(f"Data not set for dataset descriptor '{self.label}'") + self._global_dimension = next( + dim for dim in self._data.data.dims if dim != self.get_model_dimension() + ) + return self._global_dimension + + def overwrite_global_dimension(self, global_dimension: str): + """Overwrites the dataset model's global dimension.""" + self._global_dimension = global_dimension + + def swap_dimensions(self): + """Swaps the dataset model's global and model dimension.""" + global_dimension = self.get_model_dimension() + model_dimension = self.get_global_dimension() + self.overwrite_global_dimension(global_dimension) + self.overwrite_model_dimension(model_dimension) + + def set_data(self, data: xr.Dataset) -> DatasetModel: + """Sets the dataset model's data.""" + self._data = data + return self + + def get_data(self) -> xr.Dataset: + """Gets the dataset model's data.""" + return self._data + + def index_dependent(self) -> bool: + """Indicates if the dataset model is index dependent.""" + if hasattr(self, "_index_dependent"): + return self._index_dependent + return any(m.index_dependent(self) for m in self.megacomplex) + + def overwrite_index_dependent(self, index_dependent: bool): + """Overrides the index dependency of the dataset""" + self._index_dependent = index_dependent + + def global_model(self) -> bool: + """Indicates if the dataset model can model the global dimension.""" + return len(self.global_megacomplex) != 0 + + def set_coordinates(self, coords: xr.Dataset): + """Sets the dataset model's coordinates.""" + self._coords = coords + + def get_coordinates(self) -> xr.Dataset: + """Gets the dataset model's coordinates.""" + if hasattr(self, "_coords"): + return self._coords + return self._data.coords + + @model_item_validator(False) + def ensure_unique_megacomplexes(self, model: Model) -> list[str]: + + megacomplexes = [model.megacomplex[m] for m in self.megacomplex if m in model.megacomplex] + types = {type(m) for m in megacomplexes} + problems = [] + + for megacomplex_type in types: + if not megacomplex_type.glotaran_unique: + continue + instances = [m for m in megacomplexes if isinstance(m, megacomplex_type)] + n = len(instances) + if n != 1: + problems.append( + f"Multiple instances of unique megacomplex type '{instances[0].type}' " + "in dataset {self.label}" + ) + + return problems diff --git a/glotaran/model/decorator.py b/glotaran/model/decorator.py deleted file mode 100644 index 347dec4b5..000000000 --- a/glotaran/model/decorator.py +++ /dev/null @@ -1,460 +0,0 @@ -"""The model decorator.""" -from __future__ import annotations - -from typing import TYPE_CHECKING -from typing import Dict -from typing import List - -from glotaran.deprecation import deprecate -from glotaran.model.attribute import model_attribute_typed -from 
glotaran.model.clp_penalties import EqualAreaPenalty -from glotaran.model.constraint import Constraint -from glotaran.model.dataset_descriptor import DatasetDescriptor -from glotaran.model.megacomplex import Megacomplex -from glotaran.model.relation import Relation -from glotaran.model.util import wrap_func_as_method -from glotaran.model.weight import Weight -from glotaran.plugin_system.model_registration import register_model - -if TYPE_CHECKING: - from typing import Any - from typing import Callable - from typing import Tuple - from typing import Type - from typing import Union - - import numpy as np - import xarray as xr - - from glotaran.analysis.problem import Problem - from glotaran.model.base_model import Model - from glotaran.parameter import ParameterGroup - - MegacomplexMatrixFunction = Callable[ - [Type[object], Type[Model], Type[DatasetDescriptor], dict[str, int], Any], - Tuple[List[str], np.ndarray], - ] - """A `MatrixFunction` calculates the matrix for a model.""" - - GlobalMatrixFunction = Callable[ - [Type[DatasetDescriptor], np.ndarray], Tuple[List[str], np.ndarray] - ] - """A `GlobalMatrixFunction` calculates the global matrix for a model.""" - - ConstrainMatrixFunction = Callable[ - [Type[Model], ParameterGroup, List[str], np.ndarray, float], - Tuple[List[str], np.ndarray], - ] - """A `ConstrainMatrixFunction` applies constraints on a matrix.""" - - RetrieveClpFunction = Callable[ - [ - Type[Model], - ParameterGroup, - Dict[str, Union[List[str], List[List[str]]]], - Dict[str, Union[List[str], List[List[str]]]], - Dict[str, List[np.ndarray]], - Dict[str, xr.Dataset], - ], - Dict[str, List[np.ndarray]], - ] - """A `RetrieveClpFunction` retrieves the full set of clp from a reduced set.""" - - FinalizeFunction = Callable[[Problem, Dict[str, xr.Dataset]], None] - """A `FinalizeFunction` gets called after optimization.""" - - PenaltyFunction = Callable[ - [ - Type[Model], - ParameterGroup, - Dict[str, Union[List[str], List[List[str]]]], - Dict[str, List[np.ndarray]], - Dict[str, Union[np.ndarray, List[np.ndarray]]], - Dict[str, xr.Dataset], - float, - ], - np.ndarray, - ] - """A `PenaltyFunction` calculates additional penalties for the optimization.""" - - -def model( - model_type: str, - attributes: dict[str, Any] = None, - dataset_type: type[DatasetDescriptor] = DatasetDescriptor, - default_megacomplex_type: str = None, - megacomplex_types: dict[str, Megacomplex] | type[Megacomplex] = None, - global_matrix: GlobalMatrixFunction = None, - model_dimension: str = None, - global_dimension: str = None, - has_matrix_constraints_function: Callable[[type[Model]], bool] = None, - constrain_matrix_function: ConstrainMatrixFunction = None, - retrieve_clp_function: RetrieveClpFunction = None, - has_additional_penalty_function: Callable[[type[Model]], bool] = None, - additional_penalty_function: PenaltyFunction = None, - finalize_data_function: FinalizeFunction = None, - grouped: bool | Callable[[type[Model]], bool] = False, - index_dependent: bool | Callable[[type[Model]], bool] = False, -) -> Callable[[type[Model]], type[Model]]: - """The `@model` decorator is intended to be used on subclasses of :class:`glotaran.model.Model`. - It creates properties for the given attributes as well as functions to add access them. Also it - adds the functions (e.g. for `matrix`) to the model ensures they are added wrapped in a correct - way. - - Parameters - ---------- - model_type : str - Human readable string used by the parser to identify the correct model. 
- attributes : Dict[str, Any], optional - A dictionary of attribute names and types. All types must be decorated with the - :func:`glotaran.model.model_attribute` decorator, by default None. - dataset_type : Type[DatasetDescriptor], optional - A subclass of :class:`DatasetDescriptor`, by default DatasetDescriptor - megacomplex_type : Any, optional - A class for the model megacomplexes. The class must be decorated with the - :func:`glotaran.model.model_attribute` decorator, by default None - matrix : Union[MatrixFunction, IndexDependentMatrixFunction], optional - A function to calculate the matrix for the model, by default None - global_matrix : GlobalMatrixFunction, optional - A function to calculate the global matrix for the model, by default None - model_dimension : str, optional - The name of model matrix row dimension, by default None - global_dimension : str, optional - The name of model global matrix row dimension, by default None - has_matrix_constraints_function : Callable[[Type[Model]], bool], optional - True if the model as a constrain_matrix_function set, by default None - constrain_matrix_function : ConstrainMatrixFunction, optional - A function to constrain the global matrix for the model, by default None - retrieve_clp_function : RetrieveClpFunction, optional - A function to retrieve the full clp from the reduced, by default None - has_additional_penalty_function : Callable[[Type[Model]], bool], optional - True if model has a additional_penalty_function set, by default None - additional_penalty_function : PenaltyFunction, optional - A function to calculate additional penalties when optimizing the model, by default None - finalize_data_function : FinalizeFunction, optional - A function to finalize data after optimization, by default None - grouped : Union[bool, Callable[[Type[Model]], bool]], optional - True if model described a grouped problem, by default False - index_dependent : Union[bool, Callable[[Type[Model]], bool]], optional - True if model described a index dependent problem, by default False - - Returns - ------- - Callable - Returns a decorated model function - - Raises - ------ - ValueError - If model implements meth:`has_matrix_constraints_function` but not - meth:`constrain_matrix_function` and meth:`retrieve_clp_function` - ValueError - If model implements meth:`has_additional_penalty_function` but not - meth:`additional_penalty_function` - """ - - def decorator(cls): - - setattr(cls, "_model_type", model_type) - setattr(cls, "finalize_data", finalize_data_function) - - _set_constraints_functions( - cls, has_matrix_constraints_function, constrain_matrix_function, retrieve_clp_function - ) - - _set_additional_penalty_functions( - cls, has_additional_penalty_function, additional_penalty_function - ) - - _set_grouped_and_indexdependent(cls, grouped, index_dependent) - - _set_dimensions(cls, model_type, model_dimension, global_dimension) - - if global_matrix: - g_mat = wrap_func_as_method(cls, name="global_matrix")(global_matrix) - g_mat = staticmethod(g_mat) - setattr(cls, "global_matrix", g_mat) - else: - setattr(cls, "global_matrix", None) - - if not hasattr(cls, "_glotaran_model_attributes"): - setattr(cls, "_glotaran_model_attributes", {}) - else: - setattr( - cls, - "_glotaran_model_attributes", - getattr(cls, "_glotaran_model_attributes").copy(), - ) - - megacomplex_cls = _set_megacomplexes( - cls, model_type, default_megacomplex_type, megacomplex_types - ) - - # We add the standard attributes here. 
- if not issubclass(dataset_type, DatasetDescriptor): - raise ValueError( - f"Dataset descriptor of model {model_type} is not a subclass of DatasetDescriptor" - ) - attributes["dataset"] = dataset_type - attributes["megacomplex"] = megacomplex_cls - attributes["weights"] = Weight - attributes["relations"] = Relation - attributes["constraints"] = Constraint - attributes["clp_area_penalties"] = EqualAreaPenalty - - # Set annotations and methods for attributes - for attr_name, attr_type in attributes.items(): - - # store for internal lookups - getattr(cls, "_glotaran_model_attributes")[attr_name] = None - - # create and attach the property to class - attr_prop = _create_property_for_attribute(cls, attr_name, attr_type) - setattr(cls, attr_name, attr_prop) - - # properties with labels are implemented as dicts, whereas properties - # without as arrays. Thus the need different setters. - if getattr(attr_type, "_glotaran_has_label"): - get_item = _create_get_func(cls, attr_name, attr_type) - setattr(cls, get_item.__name__, get_item) - set_item = _create_set_func(cls, attr_name, attr_type) - setattr(cls, set_item.__name__, set_item) - - else: - add_item = _create_add_func(cls, attr_name, attr_type) - setattr(cls, add_item.__name__, add_item) - - init = _create_init_func(cls, attributes) - setattr(cls, "__init__", init) - - register_model(model_type, cls) - - return cls - - return decorator - - -def _create_init_func(cls, attributes): - @wrap_func_as_method(cls) - def __init__(self): - for attr_name, attr_item in attributes.items(): - if getattr(attr_item, "_glotaran_has_label"): - setattr(self, f"_{attr_name}", {}) - else: - setattr(self, f"_{attr_name}", []) - super(cls, self).__init__() - - return __init__ - - -def _create_add_func(cls, name, item_type): - @wrap_func_as_method(cls, name=f"add_{name}", annotations={"item": item_type}) - def add_item(self, item: item_type): - f"""Adds an `{item_type.__name__}` object. - - Parameters - ---------- - item : - The `{item_type.__name__}` item. - """ - - if not isinstance(item, item_type) and ( - not hasattr(item_type, "_glotaran_model_attribute_typed") - or not isinstance(item, tuple(item_type._glotaran_model_attribute_types.values())) - ): - raise TypeError - getattr(self, f"_{name}").append(item) - - return add_item - - -def _create_get_func(cls, name, item_type): - @wrap_func_as_method(cls, name=f"get_{name}", annotations={"return": item_type}) - def get_item(self, label: str) -> item_type: - f""" - Returns the `{item_type.__name__}` object with the given label. - - Parameters - ---------- - label : - The label of the `{item_type.__name__}` object. - """ - return getattr(self, f"_{name}")[label] - - return get_item - - -def _create_set_func(cls, name, item_type): - @wrap_func_as_method(cls, name=f"set_{name}", annotations={"item": item_type}) - def set_item(self, label: str, item: item_type): - f""" - Sets the `{item_type.__name__}` object with the given label with to the item. - - Parameters - ---------- - label : - The label of the `{item_type.__name__}` object. - item : - The `{item_type.__name__}` item. 
- """ - - if ( - not isinstance(item, item_type) - and ( - not hasattr(item_type, "_glotaran_model_attribute_typed") - or not isinstance(item, tuple(item_type._glotaran_model_attribute_types.values())) - ) - and not isinstance(item, Megacomplex) - ): - raise TypeError - getattr(self, f"_{name}")[label] = item - - return set_item - - -def _set_megacomplexes(cls, model_type, default_megacomplex_type, megacomplex_types): - @model_attribute_typed({}) - class MetaMegacomplex: - """This class holds all Megacomplex types defined by a model.""" - - if not isinstance(megacomplex_types, dict): - megacomplex_types = {model_type: megacomplex_types} - for name, megacomplex_type in megacomplex_types.items(): - if not issubclass(megacomplex_type, Megacomplex): - raise TypeError( - f"Megacomplex type {name}(megacomplex_type) is not a subclass of Megacomplex" - ) - MetaMegacomplex.add_type(name, megacomplex_type) - - if default_megacomplex_type is None: - default_megacomplex_type = next(iter(megacomplex_types.keys())) - setattr(MetaMegacomplex, "_glotaran_model_attribute_default_type", default_megacomplex_type) - return MetaMegacomplex - - -def _create_property_for_attribute(cls, name, attribute_type): - - return_type = ( - Dict[str, attribute_type] - if hasattr(attribute_type, "_glotaran_has_label") - else List[attribute_type] - ) - - doc_type = "dictionary" if hasattr(attribute_type, "_glotaran_has_label") else "list" - - @property - @wrap_func_as_method(cls, name=f"{name}", annotations={"return": return_type}) - def attribute(self) -> return_type: - f"""A {doc_type} containing {type.__name__}""" - return getattr(self, f"_{name}") - - return attribute - - -def _set_constraints_functions( - cls, has_matrix_constraints_function, constrain_matrix_function, retrieve_clp_function -): - if has_matrix_constraints_function: - if not constrain_matrix_function: - raise ValueError( - "Model implements `has_matrix_constraints_function` " - "but not `constrain_matrix_function`" - ) - if not retrieve_clp_function: - raise ValueError( - "Model implements `has_matrix_constraints_function` " - "but not `retrieve_clp_function`" - ) - has_c_mat = wrap_func_as_method(cls, name="has_matrix_constraints_function")( - has_matrix_constraints_function - ) - c_mat = wrap_func_as_method(cls, name="constrain_matrix_function")( - constrain_matrix_function - ) - r_clp = wrap_func_as_method(cls, name="retrieve_clp_function")(retrieve_clp_function) - setattr(cls, "has_matrix_constraints_function", has_c_mat) - setattr(cls, "constrain_matrix_function", c_mat) - setattr(cls, "retrieve_clp_function", r_clp) - else: - setattr(cls, "has_matrix_constraints_function", None) - setattr(cls, "constrain_matrix_function", None) - setattr(cls, "retrieve_clp_function", None) - - -def _set_additional_penalty_functions( - cls, has_additional_penalty_function, additional_penalty_function -): - if has_additional_penalty_function: - if not additional_penalty_function: - raise ValueError( - "Model implements `has_additional_penalty_function`" - "but not `additional_penalty_function`" - ) - has_pen = wrap_func_as_method(cls, name="has_additional_penalty_function")( - has_additional_penalty_function - ) - pen = wrap_func_as_method(cls, name="additional_penalty_function")( - additional_penalty_function - ) - setattr(cls, "additional_penalty_function", pen) - setattr(cls, "has_additional_penalty_function", has_pen) - else: - setattr(cls, "has_additional_penalty_function", None) - setattr(cls, "additional_penalty_function", None) - - -def 
_set_grouped_and_indexdependent(cls, grouped, index_dependent): - setattr( - cls, - "grouped", - grouped if callable(grouped) else lambda model: grouped, - ) - - @deprecate( - deprecated_qual_name_usage="glotaran.model.base_model.Model.index_dependent", - new_qual_name_usage=("glotaran.model.megacomplex.Megacomplex.index_dependent"), - to_be_removed_in_version="0.6.0", - importable_indices=(2, 2), - ) - def idep(self): - return index_dependent(self) if callable(index_dependent) else index_dependent - - setattr(cls, "index_dependent", idep) - - # TODO: This is temporary - if callable(index_dependent): - setattr(cls, "overwrite_index_dependent", index_dependent) - - -def _set_dimensions(cls, model_type, model_dimension, global_dimension): - @property - @deprecate( - deprecated_qual_name_usage="glotaran.model.base_model.Model.model_dimension", - new_qual_name_usage=( - "glotaran.model.dataset_descriptor.DatasetDescriptor.get_model_dimension" - ), - to_be_removed_in_version="0.6.0", - importable_indices=(2, 2), - ) - def mdim(self): - return model_dimension - - if model_dimension is None: - raise ValueError(f"Model dimension not specified for model {model_type}") - setattr(cls, "model_dimension", mdim) - - @property - @deprecate( - deprecated_qual_name_usage="glotaran.model.base_model.Model.global_dimension", - new_qual_name_usage=( - "glotaran.model.dataset_descriptor.DatasetDescriptor.get_global_dimension" - ), - to_be_removed_in_version="0.6.0", - importable_indices=(2, 2), - ) - def gdim(self): - return global_dimension - - if global_dimension is None: - raise ValueError(f"Global dimension not specified for model {model_type}") - setattr(cls, "global_dimension", gdim) diff --git a/glotaran/model/interval_property.py b/glotaran/model/interval_property.py index e1d3c6888..78b7d5e40 100644 --- a/glotaran/model/interval_property.py +++ b/glotaran/model/interval_property.py @@ -5,14 +5,14 @@ from typing import List from typing import Tuple -from glotaran.model import model_attribute +from glotaran.model.item import model_item -@model_attribute( +@model_item( properties={ "interval": {"type": List[Tuple[Any, Any]], "default": None, "allow_none": True}, }, - no_label=True, + has_label=False, ) class IntervalProperty: """Applies a relation between clps as diff --git a/glotaran/model/attribute.py b/glotaran/model/item.py similarity index 78% rename from glotaran/model/attribute.py rename to glotaran/model/item.py index 05024fb13..82d1c9519 100644 --- a/glotaran/model/attribute.py +++ b/glotaran/model/item.py @@ -1,8 +1,11 @@ -"""The model attribute decorator.""" +"""The model item decorator.""" from __future__ import annotations import copy from typing import TYPE_CHECKING +from typing import Callable +from typing import List +from typing import Type from glotaran.model.property import ModelProperty from glotaran.model.util import wrap_func_as_method @@ -10,18 +13,27 @@ if TYPE_CHECKING: from typing import Any - from typing import Callable - from glotaran.model.base_model import Model + from glotaran.model.model import Model from glotaran.parameter import ParameterGroup + Validator = Callable[ + [Type[object], Type[Model]], + List[str], + ] -def model_attribute( + ValidatorParameter = Callable[ + [Type[object], Type[Model], Type[ParameterGroup]], + List[str], + ] + + +def model_item( properties: Any | dict[str, dict[str, Any]] = {}, has_type: bool = False, - no_label: bool = False, + has_label: bool = True, ) -> Callable: - """The `@model_attribute` decorator adds the given properties to the class. 
Further it adds + """The `@model_item` decorator adds the given properties to the class. Further it adds classmethods for deserialization, validation and printing. By default, a `label` property is added. @@ -34,28 +46,28 @@ def model_attribute( * default: a default value (optional) * allow_none: if `True`, the property can be set to None (optional) - Classes with the `model_attribute` decorator intended to be used in glotaran models. + Classes with the `model_item` decorator intended to be used in glotaran models. Parameters ---------- properties : A dictionary of property names and options. - has_type: + has_type : If true, a type property will added. Used for model attributes, which can have more then one type. - no_label: - If true no label property will be added. + has_label : + If false no label property will be added. """ def decorator(cls): - setattr(cls, "_glotaran_has_label", not no_label) - setattr(cls, "_glotaran_model_attribute", True) + setattr(cls, "_glotaran_has_label", has_label) + setattr(cls, "_glotaran_model_item", True) # store for later sanity checking if not hasattr(cls, "_glotaran_properties"): setattr(cls, "_glotaran_properties", []) - if not no_label: + if has_label: doc = f"The label of {cls.__name__} item." prop = ModelProperty(cls, "label", str, doc, None, False) setattr(cls, "label", prop) @@ -88,6 +100,9 @@ def decorator(cls): if name not in getattr(cls, "_glotaran_properties"): getattr(cls, "_glotaran_properties").append(name) + validators = _get_validators(cls) + setattr(cls, "_glotaran_validators", validators) + init = _create_init_func(cls) setattr(cls, "__init__", init) @@ -117,12 +132,13 @@ def decorator(cls): return decorator -def model_attribute_typed( +def model_item_typed( + *, types: dict[str, Any], - no_label=False, + has_label: bool = True, default_type: str = None, ): - """The model_attribute_typed decorator adds attributes to the class to enable + """The model_item_typed decorator adds attributes to the class to enable the glotaran model parser to infer the correct class for an item when there are multiple variants. @@ -130,16 +146,16 @@ def model_attribute_typed( ---------- types : A dictionary of types and options. - no_label: - If `True` no label property will be added. + has_label: + If `False` no label property will be added. 
""" def decorator(cls): - setattr(cls, "_glotaran_model_attribute", True) - setattr(cls, "_glotaran_model_attribute_typed", True) - setattr(cls, "_glotaran_model_attribute_types", types) - setattr(cls, "_glotaran_model_attribute_default_type", default_type) + setattr(cls, "_glotaran_model_item", True) + setattr(cls, "_glotaran_model_item_typed", True) + setattr(cls, "_glotaran_model_item_types", types) + setattr(cls, "_glotaran_model_item_default_type", default_type) get_default_type = _create_get_default_type_func(cls) setattr(cls, "get_default_type", get_default_type) @@ -147,18 +163,36 @@ def decorator(cls): add_type = _create_add_type_func(cls) setattr(cls, "add_type", add_type) - setattr(cls, "_glotaran_has_label", not no_label) + setattr(cls, "_glotaran_has_label", has_label) return cls return decorator +def model_item_validator(need_parameter: bool): + """The model_item_validator marks a method of a model item as validation function""" + + def decorator(method: Validator | ValidatorParameter): + setattr(method, "_glotaran_validator", need_parameter) + return method + + return decorator + + +def _get_validators(cls): + return { + method: getattr(getattr(cls, method), "_glotaran_validator") + for method in dir(cls) + if hasattr(getattr(cls, method), "_glotaran_validator") + } + + def _create_get_default_type_func(cls): @classmethod @wrap_func_as_method(cls) def get_default_type(cls) -> str: - return getattr(cls, "_glotaran_model_attribute_default_type") + return getattr(cls, "_glotaran_model_item_default_type") return get_default_type @@ -167,7 +201,7 @@ def _create_add_type_func(cls): @classmethod @wrap_func_as_method(cls) def add_type(cls, type_name: str, attribute_type: type): - getattr(cls, "_glotaran_model_attribute_types")[type_name] = attribute_type + getattr(cls, "_glotaran_model_item_types")[type_name] = attribute_type return add_type @@ -220,7 +254,7 @@ def from_list(ncls, values: list) -> cls: A list of values. """ item = ncls() - if len(values) is not len(ncls._glotaran_properties): + if len(values) != len(ncls._glotaran_properties): raise ValueError( f"To few or much parameters for '{ncls.__name__}'" f"\nGot: {values}\nWant: {ncls._glotaran_properties}" @@ -236,7 +270,7 @@ def from_list(ncls, values: list) -> cls: def _create_validation_func(cls): @wrap_func_as_method(cls) - def validate(self, model: Model, parameters=None) -> list[str]: + def validate(self, model: Model, parameters: ParameterGroup | None = None) -> list[str]: f"""Creates a list of parameters needed by this instance of {cls.__name__} not present in a set of parameters. @@ -249,12 +283,19 @@ def validate(self, model: Model, parameters=None) -> list[str]: missing : A list the missing will be appended to. 
""" - errors = [] + problems = [] for name in self._glotaran_properties: prop = getattr(self.__class__, name) value = getattr(self, name) - errors += prop.validate(value, model, parameters) - return errors + problems += prop.validate(value, model, parameters) + for validator, need_parameter in self._glotaran_validators.items(): + if need_parameter: + if parameters is not None: + problems += getattr(self, validator)(model, parameters) + else: + problems += getattr(self, validator)(model) + + return problems return validate diff --git a/glotaran/model/megacomplex.py b/glotaran/model/megacomplex.py index 906598bf8..3e568c6ad 100644 --- a/glotaran/model/megacomplex.py +++ b/glotaran/model/megacomplex.py @@ -1,31 +1,100 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import Dict +from typing import List import xarray as xr +from typing_inspect import get_args +from typing_inspect import is_generic_type -from glotaran.model import DatasetDescriptor -from glotaran.model import model_attribute +from glotaran.model.item import model_item +from glotaran.model.item import model_item_typed +from glotaran.plugin_system.megacomplex_registration import register_megacomplex if TYPE_CHECKING: from typing import Any + from glotaran.model import DatasetModel + + +def create_model_megacomplex_type( + megacomplex_types: dict[str, Megacomplex], default_type: str = None +) -> type: + @model_item_typed(types=megacomplex_types, default_type=default_type) + class ModelMegacomplex: + """This class holds all Megacomplex types defined by a model.""" + + return ModelMegacomplex + def megacomplex( - dimension: str, + *, + dimension: str | None = None, + model_items: dict[str, dict[str, Any]] = None, properties: Any | dict[str, dict[str, Any]] = None, - attributes: dict[str, dict[str, Any]] = None, - dataset_attributes: dict[str, dict[str, Any]] = None, + dataset_model_items: dict[str, dict[str, Any]] = None, + dataset_properties: Any | dict[str, dict[str, Any]] = None, + unique: bool = False, + register_as: str | None = None, ): """The `@megacomplex` decorator is intended to be used on subclasses of :class:`glotaran.model.Megacomplex`. It registers the megacomplex model and makes it available in analysis models. 
""" - - # TODO: this is temporary and will change in follow up PR properties = properties if properties is not None else {} - properties["dimension"] = {"type": str, "default": dimension} - return model_attribute(properties=properties, has_type=True) + properties["dimension"] = {"type": str} + if dimension is not None: + properties["dimension"]["default"] = dimension + + if model_items is None: + model_items = {} + else: + model_items, properties = _add_model_items_to_properties(model_items, properties) + + dataset_properties = dataset_properties if dataset_properties is not None else {} + if dataset_model_items is None: + dataset_model_items = {} + else: + dataset_model_items, dataset_properties = _add_model_items_to_properties( + dataset_model_items, dataset_properties + ) + + def decorator(cls): + + setattr(cls, "_glotaran_megacomplex_model_items", model_items) + setattr(cls, "_glotaran_megacomplex_dataset_model_items", dataset_model_items) + setattr(cls, "_glotaran_megacomplex_dataset_properties", dataset_properties) + setattr(cls, "_glotaran_megacomplex_unique", unique) + + megacomplex_type = model_item(properties=properties, has_type=True)(cls) + + if register_as is not None: + register_megacomplex(register_as, megacomplex_type) + + return megacomplex_type + + return decorator + + +def _add_model_items_to_properties(model_items: dict, properties: dict) -> tuple[dict, dict]: + for name, item in model_items.items(): + item_type = item["type"] if isinstance(item, dict) else item + property_type = str + + if is_generic_type(item_type): + if item_type._name == "List": + property_type = List[str] + item_type = get_args(item_type)[0] + elif item_type._name == "Dict": + property_type = Dict[str, str] + item_type = get_args(item_type)[1] + + property_dict = item.copy() if isinstance(item, dict) else {} + property_dict["type"] = property_type + properties[name] = property_dict + model_items[name] = item_type + return model_items, properties class Megacomplex: @@ -37,11 +106,30 @@ class Megacomplex: def calculate_matrix( self, - dataset_model: DatasetDescriptor, + dataset_model: DatasetModel, indices: dict[str, int], **kwargs, ) -> xr.DataArray: raise NotImplementedError - def index_dependent(self, dataset: DatasetDescriptor) -> bool: + def index_dependent(self, dataset_model: DatasetModel) -> bool: raise NotImplementedError + + def finalize_data(self, dataset_model: DatasetModel, data: xr.Dataset): + raise NotImplementedError + + @classmethod + def glotaran_model_items(cls) -> str: + return cls._glotaran_megacomplex_model_items + + @classmethod + def glotaran_dataset_model_items(cls) -> str: + return cls._glotaran_megacomplex_dataset_model_items + + @classmethod + def glotaran_dataset_properties(cls) -> str: + return cls._glotaran_megacomplex_dataset_properties + + @classmethod + def glotaran_unique(cls) -> bool: + return cls._glotaran_megacomplex_unique diff --git a/glotaran/model/model.py b/glotaran/model/model.py new file mode 100644 index 000000000..b84427142 --- /dev/null +++ b/glotaran/model/model.py @@ -0,0 +1,332 @@ +"""A base class for global analysis models.""" +from __future__ import annotations + +import copy +from typing import List +from warnings import warn + +from glotaran.model.clp_penalties import EqualAreaPenalty +from glotaran.model.constraint import Constraint +from glotaran.model.dataset_model import create_dataset_model_type +from glotaran.model.megacomplex import Megacomplex +from glotaran.model.megacomplex import create_model_megacomplex_type +from 
glotaran.model.relation import Relation +from glotaran.model.util import ModelError +from glotaran.model.weight import Weight +from glotaran.parameter import Parameter +from glotaran.parameter import ParameterGroup +from glotaran.utils.ipython import MarkdownStr + +default_model_items = { + "clp_area_penalties": EqualAreaPenalty, + "constraints": Constraint, + "relations": Relation, + "weights": Weight, +} + +default_dataset_properties = { + "megacomplex": List[str], + "megacomplex_scale": {"type": List[Parameter], "allow_none": True}, + "global_megacomplex": {"type": List[str], "default": []}, + "global_megacomplex_scale": {"type": List[Parameter], "default": None, "allow_none": True}, + "scale": {"type": Parameter, "default": None, "allow_none": True}, +} + + +class Model: + """A base class for global analysis models.""" + + def __init__( + self, + *, + megacomplex_types: dict[str, type[Megacomplex]], + default_megacomplex_type: str | None = None, + ): + self._megacomplex_types = megacomplex_types + self._default_megacomplex_type = default_megacomplex_type or next(iter(megacomplex_types)) + + self._model_items = {} + self._dataset_properties = {} + self._add_default_items_and_properties() + self._add_megacomplexe_types() + self._add_dataset_type() + + @classmethod + def from_dict( + cls, + model_dict_ref: dict, + *, + megacomplex_types: dict[str, type[Megacomplex]], + default_megacomplex_type: str | None = None, + ) -> Model: + """Creates a model from a dictionary. + + Parameters + ---------- + model_dict : + Dictionary containing the model. + """ + + model = cls( + megacomplex_types=megacomplex_types, default_megacomplex_type=default_megacomplex_type + ) + + model_dict = copy.deepcopy(model_dict_ref) + + # iterate over items + for name, items in list(model_dict.items()): + + if name not in model._model_items: + warn(f"Unknown model item type '{name}'.") + continue + + is_list = isinstance(getattr(model, name), list) + + if is_list: + model._add_list_items(name, items) + else: + model._add_dict_items(name, items) + + return model + + def _add_dict_items(self, name: str, items: dict): + + for label, item in items.items(): + item_cls = self._model_items[name] + is_typed = hasattr(item_cls, "_glotaran_model_item_typed") + if is_typed: + if "type" not in item and item_cls.get_default_type() is None: + raise ValueError(f"Missing type for attribute '{name}'") + item_type = item.get("type", item_cls.get_default_type()) + + types = item_cls._glotaran_model_item_types + if item_type not in types: + raise ValueError(f"Unknown type '{item_type}' for attribute '{name}'") + item_cls = types[item_type] + item["label"] = label + item = item_cls.from_dict(item) + getattr(self, name)[label] = item + + def _add_list_items(self, name: str, items: list): + + for item in items: + item_cls = self._model_items[name] + is_typed = hasattr(item_cls, "_glotaran_model_item_typed") + if is_typed: + if "type" not in item: + raise ValueError(f"Missing type for attribute '{name}'") + item_type = item["type"] + + if item_type not in item_cls._glotaran_model_item_types: + raise ValueError(f"Unknown type '{item_type}' for attribute '{name}'") + item_cls = item_cls._glotaran_model_item_types[item_type] + item = item_cls.from_dict(item) + getattr(self, name).append(item) + + def _add_megacomplexe_types(self): + + for name, megacomplex_type in self._megacomplex_types.items(): + if not issubclass(megacomplex_type, Megacomplex): + raise TypeError( + f"Megacomplex type {name}({megacomplex_type}) is not a subclass of 
Megacomplex" + ) + self._add_megacomplex_type(megacomplex_type) + + model_megacomplex_type = create_model_megacomplex_type( + self._megacomplex_types, self.default_megacomplex + ) + self._add_model_item("megacomplex", model_megacomplex_type) + + def _add_megacomplex_type(self, megacomplex_type: type[Megacomplex]): + + for name, item in megacomplex_type.glotaran_model_items().items(): + self._add_model_item(name, item) + + for name, item in megacomplex_type.glotaran_dataset_model_items().items(): + self._add_model_item(name, item) + + for name, prop in megacomplex_type.glotaran_dataset_properties().items(): + self._add_dataset_property(name, prop) + + def _add_model_item(self, name: str, item: type): + if name in self._model_items: + if self._model_items[name] != item: + raise ModelError( + f"Cannot add item of type {name}. Model item '{name}' was already defined" + "as a different type." + ) + return + self._model_items[name] = item + + if getattr(item, "_glotaran_has_label"): + setattr(self, f"{name}", {}) + else: + setattr(self, f"{name}", []) + + def _add_dataset_property(self, name: str, dataset_property: dict[str, any]): + if name in self._dataset_properties: + known_type = ( + self._dataset_properties[name] + if not isinstance(self._dataset_properties, dict) + else self._dataset_properties[name]["type"] + ) + new_type = ( + dataset_property + if not isinstance(dataset_property, dict) + else dataset_property["type"] + ) + if known_type != new_type: + raise ModelError( + f"Cannot add dataset property of type {name} as it was already defined" + "as a different type." + ) + return + self._dataset_properties[name] = dataset_property + + def _add_default_items_and_properties(self): + for name, item in default_model_items.items(): + self._add_model_item(name, item) + + for name, prop in default_dataset_properties.items(): + self._add_dataset_property(name, prop) + + def _add_dataset_type(self): + dataset_model_type = create_dataset_model_type(self._dataset_properties) + self._add_model_item("dataset", dataset_model_type) + + @property + def default_megacomplex(self) -> str: + """The default megacomplex used by this model.""" + return self._default_megacomplex_type + + @property + def megacomplex_types(self) -> dict[str, type[Megacomplex]]: + """The megacomplex types used by this model.""" + return self._megacomplex_types + + @property + def model_items(self) -> dict[str, type[object]]: + """The model_items types used by this model.""" + return self._model_items + + @property + def global_megacomplex(self) -> dict[str, Megacomplex]: + """Alias for `glotaran.model.megacomplex`. Needed internally.""" + return self.megacomplex + + def need_index_dependent(self): + """Returns true if e.g. relations with intervals are present.""" + return any(i.interval is not None for i in self.constraints + self.relations) + + def problem_list(self, parameters: ParameterGroup = None) -> list[str]: + """ + Returns a list with all problems in the model and missing parameters if specified. + + Parameters + ---------- + + parameter : + The parameter to validate. 
+ """ + problems = [] + + for name in self._model_items: + items = getattr(self, name) + if isinstance(items, list): + for item in items: + problems += item.validate(self, parameters=parameters) + else: + for _, item in items.items(): + problems += item.validate(self, parameters=parameters) + + return problems + + def validate(self, parameters: ParameterGroup = None, raise_exception: bool = False) -> str: + """ + Returns a string listing all problems in the model and missing parameters if specified. + + Parameters + ---------- + + parameter : + The parameter to validate. + """ + result = "" + + problems = self.problem_list(parameters) + if problems: + result = f"Your model has {len(problems)} problems:\n" + for p in problems: + result += f"\n * {p}" + if raise_exception: + raise ModelError(result) + else: + result = "Your model is valid." + return result + + def valid(self, parameters: ParameterGroup = None) -> bool: + """Returns `True` if the number problems in the model is 0, else `False` + + Parameters + ---------- + + parameter : + The parameter to validate. + """ + return len(self.problem_list(parameters)) == 0 + + def markdown( + self, + parameters: ParameterGroup = None, + initial_parameters: ParameterGroup = None, + base_heading_level: int = 1, + ) -> MarkdownStr: + """Formats the model as Markdown string. + + Parameters will be included if specified. + + Parameters + ---------- + parameter: ParameterGroup + Parameter to include. + initial_parameters: ParameterGroup + Initial values for the parameters. + base_heading_level: int + Base heading level of the markdown sections. + + E.g.: + + - If it is 1 the string will start with '# Model'. + - If it is 3 the string will start with '### Model'. + """ + base_heading = "#" * base_heading_level + string = f"{base_heading} Model\n\n" + string += "_Megacomplex Types_: " + string += ", ".join(self._megacomplex_types) + string += "\n\n" + + for name in self._model_items: + items = getattr(self, name) + if not items: + continue + + string += f"{base_heading}# {name.replace('_', ' ').title()}\n\n" + + if isinstance(items, dict): + items = items.values() + for item in items: + item_str = item.mprint( + parameters=parameters, initial_parameters=initial_parameters + ).split("\n") + string += f"* {item_str[0]}\n" + for s in item_str[1:]: + string += f" {s}\n" + string += "\n" + return MarkdownStr(string) + + def _repr_markdown_(self) -> str: + """Special method used by ``ipython`` to render markdown.""" + return str(self.markdown(base_heading_level=3)) + + def __str__(self): + return str(self.markdown()) diff --git a/glotaran/model/property.py b/glotaran/model/property.py index c1bb47a9f..92518badc 100644 --- a/glotaran/model/property.py +++ b/glotaran/model/property.py @@ -50,18 +50,18 @@ def validate(self, value, model, parameters=None) -> typing.List[str]: return [] missing_model = [] - if hasattr(model, f"set_{self._name}") or hasattr(model, f"add_{self._name}"): - attr = getattr(model, self._name) + if self._name in model.model_items: + items = getattr(model, self._name) if isinstance(value, list): for item in value: - if item not in attr: + if item not in items: missing_model.append((self._name, item)) elif isinstance(value, dict): for item in value.values(): - if item not in attr: + if item not in items: missing_model.append((self._name, item)) - elif value not in attr: + elif value not in items: missing_model.append((self._name, value)) missing_model = [ f"Missing Model Item: '{name}'['{label}']" for name, label in missing_model diff --git 
a/glotaran/model/relation.py b/glotaran/model/relation.py index b5c210d09..d843829eb 100644 --- a/glotaran/model/relation.py +++ b/glotaran/model/relation.py @@ -1,18 +1,18 @@ """ Glotaran Relation """ from __future__ import annotations -from glotaran.model import model_attribute from glotaran.model.interval_property import IntervalProperty +from glotaran.model.item import model_item from glotaran.parameter import Parameter -@model_attribute( +@model_item( properties={ "source": str, "target": str, "parameter": Parameter, }, - no_label=True, + has_label=False, ) class Relation(IntervalProperty): """Applies a relation between clps as diff --git a/glotaran/model/test/test_model.py b/glotaran/model/test/test_model.py index 0d9729c5b..cc862b6ea 100644 --- a/glotaran/model/test/test_model.py +++ b/glotaran/model/test/test_model.py @@ -6,16 +6,20 @@ import xarray as xr from IPython.core.formatters import format_display_data +from glotaran.model import DatasetModel from glotaran.model import Megacomplex -from glotaran.model import Model from glotaran.model import megacomplex -from glotaran.model import model -from glotaran.model import model_attribute +from glotaran.model import model_item +from glotaran.model.clp_penalties import EqualAreaPenalty +from glotaran.model.constraint import Constraint +from glotaran.model.model import Model +from glotaran.model.relation import Relation +from glotaran.model.weight import Weight from glotaran.parameter import Parameter from glotaran.parameter import ParameterGroup -@model_attribute( +@model_item( properties={ "param": Parameter, "megacomplex": str, @@ -24,38 +28,59 @@ "complex": {"type": Dict[Tuple[str, str], Parameter]}, }, ) -class MockAttr: +class MockItem: pass -@megacomplex("model") -class MockMegacomplex(Megacomplex): +@model_item(has_label=False) +class MockItemNoLabel: pass -@megacomplex("model") +@megacomplex(dimension="model", model_items={"test_item1": {"type": MockItem, "allow_none": True}}) +class MockMegacomplex1(Megacomplex): + pass + + +@megacomplex(dimension="model", model_items={"test_item2": MockItemNoLabel}) class MockMegacomplex2(Megacomplex): pass -@model( - "mock_model", - attributes={"test": MockAttr}, - megacomplex_types={ - "mock_megacomplex": MockMegacomplex, - "mock_megacomplex2": MockMegacomplex2, +@megacomplex(model_items={"test_item3": List[MockItem]}) +class MockMegacomplex3(Megacomplex): + pass + + +@megacomplex(dimension="model", model_items={"test_item4": Dict[str, MockItem]}) +class MockMegacomplex4(Megacomplex): + pass + + +@megacomplex( + dimension="model", + dataset_model_items={"test_item_dataset": MockItem}, + dataset_properties={ + "test_property_dataset1": int, + "test_property_dataset2": {"type": Parameter}, }, - model_dimension="model", - global_dimension="global", ) -class MockModel(Model): +class MockMegacomplex5(Megacomplex): + pass + + +@megacomplex(dimension="model", unique=True) +class MockMegacomplex6(Megacomplex): pass @pytest.fixture -def mock_model(): - d = { - "megacomplex": {"m1": {}, "m2": ["mock_megacomplex2", "model2"]}, +def test_model(): + model_dict = { + "megacomplex": { + "m1": {"test_item1": "t2"}, + "m2": {"type": "type5", "dimension": "model2"}, + }, "weights": [ { "datasets": ["d1", "d2"], @@ -64,31 +89,54 @@ def mock_model(): "value": 5.4, } ], - "test": { + "test_item1": { "t1": { "param": "foo", "megacomplex": "m1", "param_list": ["bar", "baz"], "complex": {("s1", "s2"): "baz"}, }, - "t2": ["baz", "m2", ["foo"], 7, {}], + "t2": { + "param": "baz", + "megacomplex": "m2", + 
"param_list": ["foo"], + "complex": {}, + "default_item": 7, + }, }, "dataset": { "dataset1": { "megacomplex": ["m1"], "scale": "scale_1", + "test_item_dataset": "t1", + "test_property_dataset1": 1, + "test_property_dataset2": "bar", + }, + "dataset2": { + "megacomplex": ["m2"], + "global_megacomplex": ["m1"], + "scale": "scale_2", + "test_item_dataset": "t2", + "test_property_dataset1": 1, + "test_property_dataset2": "bar", }, - "dataset2": [["m2"], ["bar"], "scale_2"], }, } - return MockModel.from_dict(d) + model_dict["test_item_dataset"] = model_dict["test_item1"] + return Model.from_dict( + model_dict, + megacomplex_types={ + "type1": MockMegacomplex1, + "type5": MockMegacomplex5, + }, + ) @pytest.fixture def model_error(): - d = { - "megacomplex": {"m1": {}, "m2": {}}, - "test": { + model_dict = { + "megacomplex": {"m1": {}, "m2": {"type": "type2"}, "m3": {"type": "type2"}}, + "test_item1": { "t1": { "param": "fool", "megacomplex": "mX", @@ -101,10 +149,98 @@ def model_error(): "megacomplex": ["N1", "N2"], "scale": "scale_1", }, - "dataset2": [["mrX"], ["bar"], "scale_3"], + "dataset2": { + "megacomplex": ["mrX"], + "scale": "scale_3", + }, + "dataset3": { + "megacomplex": ["m2", "m3"], + }, }, } - return MockModel.from_dict(d) + return Model.from_dict( + model_dict, + megacomplex_types={ + "type1": MockMegacomplex1, + "type2": MockMegacomplex6, + }, + ) + + +def test_model_init(): + model = Model( + megacomplex_types={ + "type1": MockMegacomplex1, + "type2": MockMegacomplex2, + "type3": MockMegacomplex3, + "type4": MockMegacomplex4, + "type5": MockMegacomplex5, + } + ) + + assert model.default_megacomplex == "type1" + + assert len(model.megacomplex_types) == 5 + assert "type1" in model.megacomplex_types + assert model.megacomplex_types["type1"] == MockMegacomplex1 + assert "type2" in model.megacomplex_types + assert model.megacomplex_types["type2"] == MockMegacomplex2 + + assert hasattr(model, "test_item1") + assert isinstance(model.test_item1, dict) + assert "test_item1" in model._model_items + assert issubclass(model._model_items["test_item1"], MockItem) + + assert hasattr(model, "test_item2") + assert isinstance(model.test_item2, list) + assert "test_item2" in model._model_items + assert issubclass(model._model_items["test_item2"], MockItemNoLabel) + + assert hasattr(model, "test_item3") + assert isinstance(model.test_item3, dict) + assert "test_item3" in model._model_items + assert issubclass(model._model_items["test_item3"], MockItem) + + assert hasattr(model, "test_item4") + assert isinstance(model.test_item4, dict) + assert "test_item4" in model._model_items + assert issubclass(model._model_items["test_item4"], MockItem) + + assert hasattr(model, "test_item_dataset") + assert isinstance(model.test_item_dataset, dict) + assert "test_item_dataset" in model._model_items + assert issubclass(model._model_items["test_item_dataset"], MockItem) + assert "test_item_dataset" in model._dataset_properties + assert issubclass(model._dataset_properties["test_item_dataset"]["type"], str) + assert "test_property_dataset1" in model._dataset_properties + assert issubclass(model._dataset_properties["test_property_dataset1"], int) + assert "test_property_dataset2" in model._dataset_properties + assert issubclass(model._dataset_properties["test_property_dataset2"]["type"], Parameter) + + assert hasattr(model, "clp_area_penalties") + assert isinstance(model.clp_area_penalties, list) + assert "clp_area_penalties" in model._model_items + assert issubclass(model._model_items["clp_area_penalties"], 
EqualAreaPenalty) + + assert hasattr(model, "constraints") + assert isinstance(model.constraints, list) + assert "constraints" in model._model_items + assert issubclass(model._model_items["constraints"], Constraint) + + assert hasattr(model, "relations") + assert isinstance(model.relations, list) + assert "relations" in model._model_items + assert issubclass(model._model_items["relations"], Relation) + + assert hasattr(model, "weights") + assert isinstance(model.weights, list) + assert "weights" in model._model_items + assert issubclass(model._model_items["weights"], Weight) + + assert hasattr(model, "dataset") + assert isinstance(model.dataset, dict) + assert "dataset" in model._model_items + assert issubclass(model._model_items["dataset"], DatasetModel) @pytest.fixture @@ -113,99 +249,99 @@ def parameter(): return ParameterGroup.from_list(params) -def test_model_misc(mock_model: Model): - assert mock_model.model_type == "mock_model" - assert isinstance(mock_model.megacomplex["m1"], MockMegacomplex) - assert isinstance(mock_model.megacomplex["m2"], MockMegacomplex2) - assert mock_model.megacomplex["m1"].dimension == "model" - assert mock_model.megacomplex["m2"].dimension == "model2" - - -@pytest.mark.parametrize("attr", ["dataset", "megacomplex", "weights", "test"]) -def test_model_attr(mock_model: Model, attr: str): - assert hasattr(mock_model, attr) - if attr != "weights": - assert hasattr(mock_model, f"get_{attr}") - assert hasattr(mock_model, f"set_{attr}") - else: - assert hasattr(mock_model, f"add_{attr}") +def test_model_misc(test_model: Model): + assert isinstance(test_model.megacomplex["m1"], MockMegacomplex1) + assert isinstance(test_model.megacomplex["m2"], MockMegacomplex5) + assert test_model.megacomplex["m1"].dimension == "model" + assert test_model.megacomplex["m2"].dimension == "model2" -def test_model_validity(mock_model: Model, model_error: Model, parameter: ParameterGroup): - print(mock_model.test["t1"]) - print(mock_model.problem_list()) - print(mock_model.problem_list(parameter)) - assert mock_model.valid() - assert mock_model.valid(parameter) +def test_model_validity(test_model: Model, model_error: Model, parameter: ParameterGroup): + print(test_model.test_item1["t1"]) + print(test_model.problem_list()) + print(test_model.problem_list(parameter)) + assert test_model.valid() + assert test_model.valid(parameter) print(model_error.problem_list()) print(model_error.problem_list(parameter)) assert not model_error.valid() - assert len(model_error.problem_list()) == 4 + assert len(model_error.problem_list()) == 5 assert not model_error.valid(parameter) - assert len(model_error.problem_list(parameter)) == 8 + assert len(model_error.problem_list(parameter)) == 9 -def test_items(mock_model: Model): +def test_items(test_model: Model): - assert "m1" in mock_model.megacomplex - assert "m2" in mock_model.megacomplex + assert "m1" in test_model.megacomplex + assert "m2" in test_model.megacomplex - assert "t1" in mock_model.test - t = mock_model.get_test("t1") + assert "t1" in test_model.test_item1 + t = test_model.test_item1.get("t1") assert t.param.full_label == "foo" assert t.megacomplex == "m1" assert [p.full_label for p in t.param_list] == ["bar", "baz"] assert t.default_item == 42 assert ("s1", "s2") in t.complex assert t.complex[("s1", "s2")].full_label == "baz" - assert "t2" in mock_model.test - t = mock_model.get_test("t2") + assert "t2" in test_model.test_item1 + t = test_model.test_item1.get("t2") assert t.param.full_label == "baz" assert t.megacomplex == "m2" assert 
[p.full_label for p in t.param_list] == ["foo"] assert t.default_item == 7 assert t.complex == {} - assert "dataset1" in mock_model.dataset - assert mock_model.get_dataset("dataset1").megacomplex == ["m1"] - assert mock_model.get_dataset("dataset1").scale.full_label == "scale_1" + assert "dataset1" in test_model.dataset + assert test_model.dataset.get("dataset1").megacomplex == ["m1"] + assert test_model.dataset.get("dataset1").scale.full_label == "scale_1" - assert "dataset2" in mock_model.dataset - assert mock_model.get_dataset("dataset2").megacomplex == ["m2"] - assert mock_model.get_dataset("dataset2").scale.full_label == "scale_2" + assert "dataset2" in test_model.dataset + assert test_model.dataset.get("dataset2").megacomplex == ["m2"] + assert test_model.dataset.get("dataset2").global_megacomplex == ["m1"] + assert test_model.dataset.get("dataset2").scale.full_label == "scale_2" - assert len(mock_model.weights) == 1 - w = mock_model.weights[0] + assert len(test_model.weights) == 1 + w = test_model.weights[0] assert w.datasets == ["d1", "d2"] assert w.global_interval == (1, 4) assert w.model_interval == (2, 3) assert w.value == 5.4 -def test_fill(mock_model: Model, parameter: ParameterGroup): +def test_fill(test_model: Model, parameter: ParameterGroup): data = xr.DataArray([[1]], dims=("global", "model")).to_dataset(name="data") - dataset = mock_model.get_dataset("dataset1").fill(mock_model, parameter) + dataset = test_model.dataset.get("dataset1").fill(test_model, parameter) dataset.set_data(data) assert [cmplx.label for cmplx in dataset.megacomplex] == ["m1"] assert dataset.scale == 2 + + assert dataset.get_model_dimension() == "model" + assert dataset.get_global_dimension() == "global" + dataset.swap_dimensions() + assert dataset.get_model_dimension() == "global" + assert dataset.get_global_dimension() == "model" + dataset.swap_dimensions() assert dataset.get_model_dimension() == "model" assert dataset.get_global_dimension() == "global" - data = xr.DataArray([[1]], dims=("global2", "model2")).to_dataset(name="data") - dataset = mock_model.get_dataset("dataset2").fill(mock_model, parameter) + assert not dataset.global_model() + + dataset = test_model.dataset.get("dataset2").fill(test_model, parameter) assert [cmplx.label for cmplx in dataset.megacomplex] == ["m2"] assert dataset.scale == 8 - dataset.set_data(data) assert dataset.get_model_dimension() == "model2" - assert dataset.get_global_dimension() == "global2" + assert dataset.get_global_dimension() == "model" + + assert dataset.global_model() + assert [cmplx.label for cmplx in dataset.global_megacomplex] == ["m1"] - t = mock_model.get_test("t1").fill(mock_model, parameter) + t = test_model.test_item1.get("t1").fill(test_model, parameter) assert t.param == 3 assert t.megacomplex.label == "m1" assert t.param_list == [4, 2] assert t.default_item == 42 assert t.complex == {("s1", "s2"): 2} - t = mock_model.get_test("t2").fill(mock_model, parameter) + t = test_model.test_item1.get("t2").fill(test_model, parameter) assert t.param == 2 assert t.megacomplex.label == "m2" assert t.param_list == [3] @@ -213,22 +349,22 @@ def test_fill(mock_model: Model, parameter: ParameterGroup): assert t.complex == {} -def test_model_markdown_base_heading_level(mock_model: Model): +def test_model_markdown_base_heading_level(test_model: Model): """base_heading_level applies to all sections.""" - assert mock_model.markdown().startswith("# Model") - assert "## Test" in mock_model.markdown() - assert 
mock_model.markdown(base_heading_level=3).startswith("### Model") - assert "#### Test" in mock_model.markdown(base_heading_level=3) + assert test_model.markdown().startswith("# Model") + assert "## Test" in test_model.markdown() + assert test_model.markdown(base_heading_level=3).startswith("### Model") + assert "#### Test" in test_model.markdown(base_heading_level=3) -def test_model_ipython_rendering(mock_model: Model): +def test_model_ipython_rendering(test_model: Model): """Autorendering in ipython""" - rendered_obj = format_display_data(mock_model)[0] + rendered_obj = format_display_data(test_model)[0] assert "text/markdown" in rendered_obj assert rendered_obj["text/markdown"].startswith("### Model") - rendered_markdown_return = format_display_data(mock_model.markdown())[0] + rendered_markdown_return = format_display_data(test_model.markdown())[0] assert "text/markdown" in rendered_markdown_return assert rendered_markdown_return["text/markdown"].startswith("# Model") diff --git a/glotaran/model/weight.py b/glotaran/model/weight.py index 613ea33e1..18408c467 100644 --- a/glotaran/model/weight.py +++ b/glotaran/model/weight.py @@ -3,10 +3,10 @@ from typing import List from typing import Tuple -from glotaran.model.attribute import model_attribute +from glotaran.model.item import model_item -@model_attribute( +@model_item( properties={ "datasets": {type: List[str]}, "global_interval": { @@ -21,7 +21,7 @@ }, "value": {"type": float}, }, - no_label=True, + has_label=False, ) class Weight: """The `Weight` class describes a value by which a dataset will scaled. diff --git a/glotaran/plugin_system/base_registry.py b/glotaran/plugin_system/base_registry.py index 246f0ffad..363236ebf 100644 --- a/glotaran/plugin_system/base_registry.py +++ b/glotaran/plugin_system/base_registry.py @@ -21,9 +21,9 @@ from glotaran.io.interface import DataIoInterface from glotaran.io.interface import ProjectIoInterface - from glotaran.model.base_model import Model + from glotaran.model.megacomplex import Megacomplex - _PluginType = TypeVar("_PluginType", Type[Model], DataIoInterface, ProjectIoInterface) + _PluginType = TypeVar("_PluginType", Type[Megacomplex], DataIoInterface, ProjectIoInterface) _PluginInstantiableType = TypeVar( "_PluginInstantiableType", DataIoInterface, ProjectIoInterface ) @@ -37,7 +37,7 @@ class __PluginRegistry: This is super private since if anyone messes with it, the pluginsystem could break. """ - model: MutableMapping[str, type[Model]] = {} + megacomplex: MutableMapping[str, type[Megacomplex]] = {} data_io: MutableMapping[str, DataIoInterface] = {} project_io: MutableMapping[str, ProjectIoInterface] = {} @@ -115,7 +115,7 @@ def load_plugins(): Currently used builtin entrypoints are: - ``glotaran.plugins.data_io`` - - ``glotaran.plugins.model`` + - ``glotaran.plugins.megacomplex`` - ``glotaran.plugins.project_io`` """ if "DEACTIVATE_GTA_PLUGINS" not in os.environ: # pragma: no branch @@ -364,7 +364,7 @@ def get_method_from_plugin( plugin : object | type[object], Plugin instance or class. method_name : str - Method name, e.g. load_model. + Method name, e.g. load_megacomplex. Returns ------- @@ -402,7 +402,7 @@ def show_method_help( plugin : object | type[object], Plugin instance or class. method_name : str - Method name, e.g. load_model. + Method name, e.g. load_megacomplex. 
""" method = get_method_from_plugin(plugin, method_name) help(method) diff --git a/glotaran/plugin_system/megacomplex_registration.py b/glotaran/plugin_system/megacomplex_registration.py new file mode 100644 index 000000000..2f31c50b1 --- /dev/null +++ b/glotaran/plugin_system/megacomplex_registration.py @@ -0,0 +1,157 @@ +"""Megacomplex registration convenience functions.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from tabulate import tabulate + +from glotaran.plugin_system.base_registry import __PluginRegistry +from glotaran.plugin_system.base_registry import add_plugin_to_registry +from glotaran.plugin_system.base_registry import full_plugin_name +from glotaran.plugin_system.base_registry import get_plugin_from_registry +from glotaran.plugin_system.base_registry import is_registered_plugin +from glotaran.plugin_system.base_registry import registered_plugins +from glotaran.plugin_system.base_registry import set_plugin +from glotaran.utils.ipython import MarkdownStr + +if TYPE_CHECKING: + from glotaran.model import Megacomplex + + +def register_megacomplex(megacomplex_type: str, megacomplex: type[Megacomplex]) -> None: + """Add a megacomplex to the megacomplex registry. + + Parameters + ---------- + megacomplex_type : str + Name of the megacomplex under which it is registered. + megacomplex : type[Megacomplex] + megacomplex class to be registered. + """ + add_plugin_to_registry( + plugin_register_key=megacomplex_type, + plugin=megacomplex, + plugin_registry=__PluginRegistry.megacomplex, + plugin_set_func_name="set_megacomplex_plugin", + ) + + +def is_known_megacomplex(megacomplex_type: str) -> bool: + """Check if a megacomplex is in the megacomplex registry. + + Parameters + ---------- + megacomplex_type : str + Name of the megacomplex under which it is registered. + + Returns + ------- + bool + Whether or not the megacomplex is registered. + """ + return is_registered_plugin( + plugin_register_key=megacomplex_type, plugin_registry=__PluginRegistry.megacomplex + ) + + +def get_megacomplex(megacomplex_type: str) -> type[Megacomplex]: + """Retrieve a megacomplex from the megacomplex registry. + + Parameters + ---------- + megacomplex_type : str + Name of the megacomplex under which it is registered. + + Returns + ------- + type[Megacomplex] + Megacomplex class + """ + return get_plugin_from_registry( + plugin_register_key=megacomplex_type, + plugin_registry=__PluginRegistry.megacomplex, + not_found_error_message=( + f"Unknown megacomplex type {megacomplex_type!r}. " + f"Known megacomplex types are: {known_megacomplex_names(full_names=True)}" + ), + ) + + +def known_megacomplex_names(full_names: bool = False) -> list[str]: + """Names of the registered megacomplexs. + + Parameters + ---------- + full_names : bool + Whether to display the full names the plugins are + registered under as well. + + Returns + ------- + list[str] + List of registered megacomplexs. + """ + return registered_plugins(__PluginRegistry.megacomplex, full_names=full_names) + + +def set_megacomplex_plugin(megacomplex_name: str, full_plugin_name: str) -> None: + """Set the plugin used for a specific megacomplex name. + + This function is useful when you want to resolve conflicts of installed plugins + or overwrite the plugin used for a specific megacomplex name. + + Effected functions: + + - :func:`optimize` + + Parameters + ---------- + megacomplex_name : str + Name of the megacomplex to use the plugin for. + full_plugin_name : str + Full name (import path) of the registered plugin. 
+ """ + set_plugin( + plugin_register_key=megacomplex_name, + full_plugin_name=full_plugin_name, + plugin_registry=__PluginRegistry.megacomplex, + plugin_register_key_name="megacomplex_name", + ) + + +def megacomplex_plugin_table( + *, plugin_names: bool = False, full_names: bool = False +) -> MarkdownStr: + """Return registered megacomplex plugins as markdown table. + + This is especially useful when you work with new plugins. + + Parameters + ---------- + plugin_names : bool + Whether or not to add the names of the plugins to the table. + full_names : bool + Whether to display the full names the plugins are + registered under as well. + + Returns + ------- + MarkdownStr + Markdown table of megacomplexnames. + """ + table_data = [] + megacomplex_names = known_megacomplex_names(full_names=full_names) + header_values = ["Megacomplex name"] + if plugin_names: + header_values.append("Plugin name") + for megacomplex_name in megacomplex_names: + table_data.append( + [ + f"`{megacomplex_name}`", + f"`{full_plugin_name(get_megacomplex(megacomplex_name))}`", + ] + ) + else: + table_data = [[f"`{megacomplex_name}`"] for megacomplex_name in megacomplex_names] + headers = tuple(map(lambda x: f"__{x}__", header_values)) + return MarkdownStr(tabulate(table_data, tablefmt="github", headers=headers, stralign="center")) diff --git a/glotaran/plugin_system/model_registration.py b/glotaran/plugin_system/model_registration.py deleted file mode 100644 index f8493292b..000000000 --- a/glotaran/plugin_system/model_registration.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Model registration convenience functions.""" -from __future__ import annotations - -from typing import TYPE_CHECKING - -from tabulate import tabulate - -from glotaran.plugin_system.base_registry import __PluginRegistry -from glotaran.plugin_system.base_registry import add_plugin_to_registry -from glotaran.plugin_system.base_registry import full_plugin_name -from glotaran.plugin_system.base_registry import get_plugin_from_registry -from glotaran.plugin_system.base_registry import is_registered_plugin -from glotaran.plugin_system.base_registry import registered_plugins -from glotaran.plugin_system.base_registry import set_plugin -from glotaran.utils.ipython import MarkdownStr - -if TYPE_CHECKING: - from glotaran.model import Model - - -def register_model(model_type: str, model: type[Model]) -> None: - """Add a model to the model registry. - - Parameters - ---------- - model_type : str - Name of the model under which it is registered. - model : type[Model] - model class to be registered. - """ - add_plugin_to_registry( - plugin_register_key=model_type, - plugin=model, - plugin_registry=__PluginRegistry.model, - plugin_set_func_name="set_model_plugin", - ) - - -def is_known_model(model_type: str) -> bool: - """Check if a model is in the model registry. - - Parameters - ---------- - model_type : str - Name of the model under which it is registered. - - Returns - ------- - bool - Whether or not the model is registered. - """ - return is_registered_plugin( - plugin_register_key=model_type, plugin_registry=__PluginRegistry.model - ) - - -def get_model(model_type: str) -> type[Model]: - """Retrieve a model from the model registry. - - Parameters - ---------- - model_type : str - Name of the model under which it is registered. - - Returns - ------- - type[Model] - Model class - """ - return get_plugin_from_registry( - plugin_register_key=model_type, - plugin_registry=__PluginRegistry.model, - not_found_error_message=( - f"Unknown model type {model_type!r}. 
" - f"Known model types are: {known_model_names(full_names=True)}" - ), - ) - - -def known_model_names(full_names: bool = False) -> list[str]: - """Names of the registered models. - - Parameters - ---------- - full_names : bool - Whether to display the full names the plugins are - registered under as well. - - Returns - ------- - list[str] - List of registered models. - """ - return registered_plugins(__PluginRegistry.model, full_names=full_names) - - -def set_model_plugin(model_name: str, full_plugin_name: str) -> None: - """Set the plugin used for a specific model name. - - This function is useful when you want to resolve conflicts of installed plugins - or overwrite the plugin used for a specific model name. - - Effected functions: - - - :func:`optimize` - - Parameters - ---------- - model_name : str - Name of the model to use the plugin for. - full_plugin_name : str - Full name (import path) of the registered plugin. - """ - set_plugin( - plugin_register_key=model_name, - full_plugin_name=full_plugin_name, - plugin_registry=__PluginRegistry.model, - plugin_register_key_name="model_name", - ) - - -def model_plugin_table(*, plugin_names: bool = False, full_names: bool = False) -> MarkdownStr: - """Return registered model plugins as markdown table. - - This is especially useful when you work with new plugins. - - Parameters - ---------- - plugin_names : bool - Whether or not to add the names of the plugins to the table. - full_names : bool - Whether to display the full names the plugins are - registered under as well. - - Returns - ------- - MarkdownStr - Markdown table of modelnames. - """ - table_data = [] - model_names = known_model_names(full_names=full_names) - header_values = ["Model name"] - if plugin_names: - header_values.append("Plugin name") - for model_name in model_names: - table_data.append([f"`{model_name}`", f"`{full_plugin_name(get_model(model_name))}`"]) - else: - table_data = [[f"`{model_name}`"] for model_name in model_names] - headers = tuple(map(lambda x: f"__{x}__", header_values)) - return MarkdownStr(tabulate(table_data, tablefmt="github", headers=headers, stralign="center")) diff --git a/glotaran/plugin_system/test/test_base_registry.py b/glotaran/plugin_system/test/test_base_registry.py index 1297385a9..89614a955 100644 --- a/glotaran/plugin_system/test/test_base_registry.py +++ b/glotaran/plugin_system/test/test_base_registry.py @@ -11,10 +11,10 @@ from glotaran.builtin.io.sdt.sdt_file_reader import SdtDataIo from glotaran.builtin.io.yml.yml import YmlProjectIo -from glotaran.builtin.models.kinetic_image import KineticImageModel +from glotaran.builtin.megacomplexes.decay import DecayMegacomplex from glotaran.io.interface import DataIoInterface from glotaran.io.interface import ProjectIoInterface -from glotaran.model.base_model import Model +from glotaran.model.megacomplex import Megacomplex from glotaran.plugin_system.base_registry import PluginOverwriteWarning from glotaran.plugin_system.base_registry import add_instantiated_plugin_to_registry from glotaran.plugin_system.base_registry import add_plugin_to_registry @@ -57,7 +57,7 @@ def some_method(self): "sdt": SdtDataIo("sdt"), "yml": YmlProjectIo("yml"), "glotaran.builtin.io.sdt.sdt_file_reader.SdtDataIo": SdtDataIo("sdt"), - "kinetic-image": KineticImageModel, + "decay": DecayMegacomplex, "mock_plugin": MockPlugin, "imported_plugin": MockPlugin, }, @@ -66,7 +66,7 @@ def some_method(self): mock_registry_project_io = cast( MutableMapping[str, ProjectIoInterface], copy(mock_registry_data_io) ) 
-mock_registry_model = cast(MutableMapping[str, Type[Model]], copy(mock_registry_data_io)) +mock_registry_model = cast(MutableMapping[str, Type[Megacomplex]], copy(mock_registry_data_io)) @pytest.mark.parametrize( @@ -111,7 +111,7 @@ def test_PluginOverwriteWarning(): ( ("sdt_new", SdtDataIo("sdt"), copy(mock_registry_data_io)), ("yml_new", YmlProjectIo("yml"), copy(mock_registry_project_io)), - ("kinetic-image_new", KineticImageModel, copy(mock_registry_model)), + ("decay_new", DecayMegacomplex, copy(mock_registry_model)), ), ) def test_add_plugin_to_register( @@ -253,7 +253,7 @@ def test_registered_plugins(): result = [ "sdt", "yml", - "kinetic-image", + "decay", "mock_plugin", "imported_plugin", ] @@ -368,7 +368,7 @@ def test_methods_differ_from_baseclass_table( def get_plugin_function(plugin_registry_key: str): if plugin_registry_key == "base": return MockPlugin - elif plugin_registry_key in ["sub_class", "sub_class_inst"]: + elif plugin_registry_key in {"sub_class", "sub_class_inst"}: return MockPluginSubclass result = methods_differ_from_baseclass_table( diff --git a/glotaran/plugin_system/test/test_megacomplex_registration.py b/glotaran/plugin_system/test/test_megacomplex_registration.py new file mode 100644 index 000000000..64a2865bd --- /dev/null +++ b/glotaran/plugin_system/test/test_megacomplex_registration.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from pathlib import Path +from textwrap import dedent +from typing import TYPE_CHECKING + +import pytest + +from glotaran.builtin.megacomplexes.decay import DecayMegacomplex +from glotaran.model import Megacomplex +from glotaran.model import megacomplex +from glotaran.plugin_system.base_registry import PluginOverwriteWarning +from glotaran.plugin_system.base_registry import __PluginRegistry +from glotaran.plugin_system.megacomplex_registration import get_megacomplex +from glotaran.plugin_system.megacomplex_registration import is_known_megacomplex +from glotaran.plugin_system.megacomplex_registration import known_megacomplex_names +from glotaran.plugin_system.megacomplex_registration import megacomplex_plugin_table +from glotaran.plugin_system.megacomplex_registration import register_megacomplex +from glotaran.plugin_system.megacomplex_registration import set_megacomplex_plugin + +if TYPE_CHECKING: + from _pytest.monkeypatch import MonkeyPatch + + +@pytest.fixture +def mocked_registry(monkeypatch: MonkeyPatch): + monkeypatch.setattr( + __PluginRegistry, + "megacomplex", + { + "foo": Megacomplex, + "bar": DecayMegacomplex, + "glotaran.builtin.megacomplexes.decay.DecayMegacomplex": DecayMegacomplex, + }, + ) + + +@pytest.mark.usefixtures("mocked_registry") +def test_register_megacomplex(): + """Register new megacomplex.""" + register_megacomplex("base-megacomplex", Megacomplex) + + assert "base-megacomplex" in __PluginRegistry.megacomplex + assert __PluginRegistry.megacomplex["base-megacomplex"] == Megacomplex + assert "glotaran.model.megacomplex.Megacomplex" in __PluginRegistry.megacomplex + assert __PluginRegistry.megacomplex["glotaran.model.megacomplex.Megacomplex"] == Megacomplex + assert known_megacomplex_names(full_names=True) == sorted( + [ + "foo", + "bar", + "glotaran.builtin.megacomplexes.decay.DecayMegacomplex", + "base-megacomplex", + "glotaran.model.megacomplex.Megacomplex", + ] + ) + + +@pytest.mark.usefixtures("mocked_registry") +def test_register_megacomplex_warning(): + """PluginOverwriteWarning raised pointing to correct file.""" + + with pytest.warns(PluginOverwriteWarning, 
match="DecayMegacomplex.+bar.+Dummy") as record: + + @megacomplex(register_as="bar") + class Dummy(DecayMegacomplex): + pass + + assert len(record) == 1 + assert Path(record[0].filename) == Path(__file__) + + +@pytest.mark.usefixtures("mocked_registry") +def test_is_known_megacomplex(): + """Check if megacomplexs are in registry""" + assert is_known_megacomplex("foo") + assert is_known_megacomplex("bar") + assert not is_known_megacomplex("baz") + + +@pytest.mark.usefixtures("mocked_registry") +def test_get_megacomplex(): + """Get megacomplex from registry""" + assert get_megacomplex("foo") == Megacomplex + + +@pytest.mark.usefixtures("mocked_registry") +def test_known_megacomplex_names(): + """Get megacomplex names from registry""" + assert known_megacomplex_names() == sorted(["foo", "bar"]) + + +@pytest.mark.usefixtures("mocked_registry") +def test_known_set_megacomplex_plugin(): + """Overwrite foo megacomplex""" + assert get_megacomplex("foo") == Megacomplex + set_megacomplex_plugin("foo", "glotaran.builtin.megacomplexes.decay.DecayMegacomplex") + assert get_megacomplex("foo") == DecayMegacomplex + + +@pytest.mark.usefixtures("mocked_registry") +def test_known_set_megacomplex_plugin_dot_in_megacomplex_name(): + """Raise error if megacomplex_name contains '.'""" + with pytest.raises( + ValueError, + match=r"The value of 'megacomplex_name' isn't allowed to contain the character '\.' \.", + ): + set_megacomplex_plugin("foo.bar", "glotaran.builtin.megacomplexes.decay.DecayMegacomplex") + + +@pytest.mark.usefixtures("mocked_registry") +def test_megacomplex_plugin_table(): + """Short megacomplex table.""" + expected = dedent( + """\ + | __Megacomplex name__ | + |------------------------| + | `bar` | + | `foo` | + """ + ) + print(f"{megacomplex_plugin_table()}\n") + assert f"{megacomplex_plugin_table()}\n" == expected + + +@pytest.mark.usefixtures("mocked_registry") +def test_megacomplex_plugin_table_full(): + """Full Table with all extras.""" + expected = dedent( + """\ + | __Megacomplex name__ | __Plugin name__ | + |---------------------------------------------------------|---------------------------------------------------------------------------| + | `bar` | `glotaran.builtin.megacomplexes.decay.decay_megacomplex.DecayMegacomplex` | + | `foo` | `glotaran.model.megacomplex.Megacomplex` | + | `glotaran.builtin.megacomplexes.decay.DecayMegacomplex` | `glotaran.builtin.megacomplexes.decay.decay_megacomplex.DecayMegacomplex` | + """ # noqa: E501 + ) + + assert f"{megacomplex_plugin_table(plugin_names=True,full_names=True)}\n" == expected diff --git a/glotaran/plugin_system/test/test_model_registration.py b/glotaran/plugin_system/test/test_model_registration.py deleted file mode 100644 index d838c9d00..000000000 --- a/glotaran/plugin_system/test/test_model_registration.py +++ /dev/null @@ -1,150 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from textwrap import dedent -from typing import TYPE_CHECKING - -import pytest - -from glotaran.builtin.models.kinetic_image import KineticImageModel -from glotaran.model import Model -from glotaran.model import model -from glotaran.model.attribute import model_attribute -from glotaran.model.megacomplex import Megacomplex -from glotaran.plugin_system.base_registry import PluginOverwriteWarning -from glotaran.plugin_system.base_registry import __PluginRegistry -from glotaran.plugin_system.model_registration import get_model -from glotaran.plugin_system.model_registration import is_known_model -from 
glotaran.plugin_system.model_registration import known_model_names -from glotaran.plugin_system.model_registration import model_plugin_table -from glotaran.plugin_system.model_registration import register_model -from glotaran.plugin_system.model_registration import set_model_plugin - -if TYPE_CHECKING: - from _pytest.monkeypatch import MonkeyPatch - - -@pytest.fixture -def mocked_registry(monkeypatch: MonkeyPatch): - monkeypatch.setattr( - __PluginRegistry, - "model", - { - "foo": Model, - "bar": KineticImageModel, - "glotaran.builtin.models.kinetic_image.KineticImageModel": KineticImageModel, - }, - ) - - -@pytest.mark.usefixtures("mocked_registry") -def test_register_model(): - """Register new model.""" - register_model("base-model", Model) - - assert "base-model" in __PluginRegistry.model - assert __PluginRegistry.model["base-model"] == Model - assert "glotaran.model.base_model.Model" in __PluginRegistry.model - assert __PluginRegistry.model["glotaran.model.base_model.Model"] == Model - assert known_model_names(full_names=True) == sorted( - [ - "foo", - "bar", - "glotaran.builtin.models.kinetic_image.KineticImageModel", - "base-model", - "glotaran.model.base_model.Model", - ] - ) - - -@pytest.mark.usefixtures("mocked_registry") -def test_register_model_warning(): - """PluginOverwriteWarning raised pointing to correct file.""" - - @model_attribute() - class DummyAttr(Megacomplex): - pass - - with pytest.warns(PluginOverwriteWarning, match="KineticImageModel.+bar.+Dummy") as record: - - @model( - "bar", - attributes={}, - megacomplex_types=DummyAttr, - model_dimension="", - global_dimension="", - ) - class Dummy(Model): - pass - - assert len(record) == 1 - assert Path(record[0].filename) == Path(__file__) - - -@pytest.mark.usefixtures("mocked_registry") -def test_is_known_model(): - """Check if models are in registry""" - assert is_known_model("foo") - assert is_known_model("bar") - assert not is_known_model("baz") - - -@pytest.mark.usefixtures("mocked_registry") -def test_get_model(): - """Get model from registry""" - assert get_model("foo") == Model - - -@pytest.mark.usefixtures("mocked_registry") -def test_known_model_names(): - """Get model names from registry""" - assert known_model_names() == sorted(["foo", "bar"]) - - -@pytest.mark.usefixtures("mocked_registry") -def test_known_set_model_plugin(): - """Overwrite foo model""" - assert get_model("foo") == Model - set_model_plugin("foo", "glotaran.builtin.models.kinetic_image.KineticImageModel") - assert get_model("foo") == KineticImageModel - - -@pytest.mark.usefixtures("mocked_registry") -def test_known_set_model_plugin_dot_in_model_name(): - """Raise error if model_name contains '.'""" - with pytest.raises( - ValueError, - match=r"The value of 'model_name' isn't allowed to contain the character '\.' 
\.", - ): - set_model_plugin("foo.bar", "glotaran.builtin.models.kinetic_image.KineticImageModel") - - -@pytest.mark.usefixtures("mocked_registry") -def test_model_plugin_table(): - """Short model table.""" - expected = dedent( - """\ - | __Model name__ | - |------------------| - | `bar` | - | `foo` | - """ - ) - - assert f"{model_plugin_table()}\n" == expected - - -@pytest.mark.usefixtures("mocked_registry") -def test_model_plugin_table_full(): - """Full Table with all extras.""" - expected = dedent( - """\ - | __Model name__ | __Plugin name__ | - |-----------------------------------------------------------|-------------------------------------------------------------------------------| - | `bar` | `glotaran.builtin.models.kinetic_image.kinetic_image_model.KineticImageModel` | - | `foo` | `glotaran.model.base_model.Model` | - | `glotaran.builtin.models.kinetic_image.KineticImageModel` | `glotaran.builtin.models.kinetic_image.kinetic_image_model.KineticImageModel` | - """ # noqa: E501 - ) - - assert f"{model_plugin_table(plugin_names=True,full_names=True)}\n" == expected diff --git a/glotaran/project/scheme.py b/glotaran/project/scheme.py index 0b744b2fe..bd5df63a0 100644 --- a/glotaran/project/scheme.py +++ b/glotaran/project/scheme.py @@ -33,6 +33,7 @@ class Scheme: model: Model | str parameters: ParameterGroup | str data: dict[str, xr.DataArray | xr.Dataset | str] + group: bool = True group_tolerance: float = 0.0 non_negative_least_squares: bool = False maximum_number_function_evaluations: int = None diff --git a/glotaran/project/test/test_result.py b/glotaran/project/test/test_result.py index 523d03758..60c27fd37 100644 --- a/glotaran/project/test/test_result.py +++ b/glotaran/project/test/test_result.py @@ -14,19 +14,17 @@ def dummy_result(): """Dummy result for testing.""" - model = suite.model - - model.is_grouped = False - model.is_index_dependent = False - wanted_parameters = suite.wanted_parameters data = {} for i in range(3): - e_axis = getattr(suite, "e_axis" if i == 0 else f"e_axis{i+1}") - c_axis = getattr(suite, "c_axis" if i == 0 else f"c_axis{i+1}") + global_axis = getattr(suite, "global_axis" if i == 0 else f"global_axis{i+1}") + model_axis = getattr(suite, "model_axis" if i == 0 else f"model_axis{i+1}") data[f"dataset{i+1}"] = simulate( - suite.sim_model, f"dataset{i+1}", wanted_parameters, {"e": e_axis, "c": c_axis} + suite.sim_model, + f"dataset{i+1}", + wanted_parameters, + {"global": global_axis, "model": model_axis}, ) scheme = Scheme( model=suite.model, diff --git a/glotaran/project/test/test_scheme.py b/glotaran/project/test/test_scheme.py index 43cc1e0f2..121fb999c 100644 --- a/glotaran/project/test/test_scheme.py +++ b/glotaran/project/test/test_scheme.py @@ -9,10 +9,18 @@ @pytest.fixture def mock_scheme(tmpdir): + model_yml_str = """ + megacomplex: + m1: + type: decay + k_matrix: [] + dataset: + dataset1: + megacomplex: [m1] + """ model_path = tmpdir.join("model.yml") with open(model_path, "w") as f: - model = "type: kinetic-spectrum\ndataset:\n dataset1:\n megacomplex: []" - f.write(model) + f.write(model_yml_str) parameter_path = tmpdir.join("parameters.yml") with open(parameter_path, "w") as f: @@ -48,7 +56,6 @@ def mock_scheme(tmpdir): def test_scheme(mock_scheme: Scheme): assert mock_scheme.model is not None - assert mock_scheme.model.model_type == "kinetic-spectrum" assert mock_scheme.parameters is not None assert mock_scheme.parameters.get("1") == 1.0 diff --git a/glotaran/builtin/models/kinetic_spectrum/test/__init__.py b/glotaran/test/__init__.py 
similarity index 100% rename from glotaran/builtin/models/kinetic_spectrum/test/__init__.py rename to glotaran/test/__init__.py diff --git a/glotaran/builtin/models/kinetic_spectrum/test/test_kinetic_spectrum_model.py b/glotaran/test/test_spectral_decay.py similarity index 91% rename from glotaran/builtin/models/kinetic_spectrum/test/test_kinetic_spectrum_model.py rename to glotaran/test/test_spectral_decay.py index 355eba34d..9c0fb5995 100644 --- a/glotaran/builtin/models/kinetic_spectrum/test/test_kinetic_spectrum_model.py +++ b/glotaran/test/test_spectral_decay.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import xarray as xr from glotaran.analysis.optimize import optimize from glotaran.analysis.simulation import simulate @@ -8,20 +9,23 @@ from glotaran.project import Scheme MODEL_1C_BASE = """\ -type: kinetic-spectrum dataset: dataset1: &dataset1 megacomplex: [mc1] + global_megacomplex: [mc2] initial_concentration: j1 - shape: - s1: sh1 initial_concentration: j1: compartments: [s1] parameters: ["1"] megacomplex: mc1: + type: decay k_matrix: [k1] + mc2: + type: spectral + shape: + s1: sh1 k_matrix: k1: matrix: @@ -78,8 +82,8 @@ class OneComponentOneChannel: model = load_model(MODEL_1C_NO_IRF, format_name="yml_str") initial_parameters = load_parameters(PARAMETERS_1C_INITIAL, format_name="yml_str") wanted_parameters = load_parameters(PARAMETERS_1C_WANTED, format_name="yml_str") - time = np.asarray(np.arange(0, 50, 1.5)) - spectral = np.asarray([0]) + time = xr.DataArray(np.arange(0, 50, 1.5)) + spectral = xr.DataArray([0]) axis = {"time": time, "spectral": spectral} @@ -87,25 +91,28 @@ class OneComponentOneChannelGaussianIrf: model = load_model(MODEL_1C_GAUSSIAN_IRF, format_name="yml_str") initial_parameters = load_parameters(PARAMETERS_1C_GAUSSIAN_IRF_INITIAL, format_name="yml_str") wanted_parameters = load_parameters(PARAMETERS_1C_GAUSSIAN_WANTED, format_name="yml_str") - time = np.asarray(np.arange(-10, 50, 1.5)) - spectral = np.asarray([0]) + time = xr.DataArray(np.arange(-10, 50, 1.5)) + spectral = xr.DataArray([0]) axis = {"time": time, "spectral": spectral} MODEL_3C_BASE = """\ -type: kinetic-spectrum dataset: dataset1: &dataset1 megacomplex: [mc1] + global_megacomplex: [mc2] initial_concentration: j1 irf: irf1 +megacomplex: + mc1: + type: decay + k_matrix: [k1] + mc2: + type: spectral shape: s1: sh1 s2: sh2 s3: sh3 -megacomplex: - mc1: - k_matrix: [k1] irf: irf1: type: spectral-multi-gaussian @@ -113,17 +120,17 @@ class OneComponentOneChannelGaussianIrf: width: [irf.width] shape: sh1: - type: gaussian + type: skewed-gaussian amplitude: shapes.amps.1 location: shapes.locs.1 width: shapes.width.1 sh2: - type: gaussian + type: skewed-gaussian amplitude: shapes.amps.2 location: shapes.locs.2 width: shapes.width.2 sh3: - type: gaussian + type: skewed-gaussian amplitude: shapes.amps.3 location: shapes.locs.3 width: shapes.width.3 @@ -220,8 +227,8 @@ class ThreeComponentParallel: model = load_model(MODEL_3C_PARALLEL, format_name="yml_str") initial_parameters = load_parameters(PARAMETERS_3C_INITIAL_PARALLEL, format_name="yml_str") wanted_parameters = load_parameters(PARAMETERS_3C_PARALLEL_WANTED, format_name="yml_str") - time = np.arange(-10, 100, 1.5) - spectral = np.arange(600, 750, 10) + time = xr.DataArray(np.arange(-10, 100, 1.5)) + spectral = xr.DataArray(np.arange(600, 750, 10)) axis = {"time": time, "spectral": spectral} @@ -229,8 +236,8 @@ class ThreeComponentSequential: model = load_model(MODEL_3C_SEQUENTIAL, format_name="yml_str") initial_parameters = 
load_parameters(PARAMETERS_3C_INITIAL_SEQUENTIAL, format_name="yml_str") wanted_parameters = load_parameters(PARAMETERS_3C_SIM_SEQUENTIAL, format_name="yml_str") - time = np.asarray(np.arange(-10, 50, 1.0)) - spectral = np.arange(600, 750, 5.0) + time = xr.DataArray(np.arange(-10, 50, 1.0)) + spectral = xr.DataArray(np.arange(600, 750, 5.0)) axis = {"time": time, "spectral": spectral} diff --git a/glotaran/builtin/models/kinetic_spectrum/test/test_spectral_penalties.py b/glotaran/test/test_spectral_penalties.py similarity index 81% rename from glotaran/builtin/models/kinetic_spectrum/test/test_spectral_penalties.py rename to glotaran/test/test_spectral_penalties.py index 99c8b627b..7e4fcb1c0 100644 --- a/glotaran/builtin/models/kinetic_spectrum/test/test_spectral_penalties.py +++ b/glotaran/test/test_spectral_penalties.py @@ -6,11 +6,14 @@ from copy import deepcopy import numpy as np +import xarray as xr from glotaran.analysis.optimize import optimize from glotaran.analysis.simulation import simulate -from glotaran.builtin.models.kinetic_spectrum import KineticSpectrumModel +from glotaran.builtin.megacomplexes.decay import DecayMegacomplex +from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex from glotaran.io import prepare_time_trace_dataset +from glotaran.model import Model from glotaran.parameter import ParameterGroup from glotaran.project import Scheme @@ -20,10 +23,22 @@ DatasetSpec = namedtuple("DatasetSpec", "times wavelengths irf shapes") IrfSpec = namedtuple("IrfSpec", "location width") ShapeSpec = namedtuple("ShapeSpec", "amplitude location width") -ModelSpec = namedtuple("ModelSpec", "base shape dataset_shape equ_area") +ModelSpec = namedtuple("ModelSpec", "base shape spectral_megacomplex equ_area") OptimizationSpec = namedtuple("OptimizationSpec", "nnls max_nfev") +class SpectralDecayModel(Model): + @classmethod + def from_dict(cls, model_dict): + return super().from_dict( + model_dict, + megacomplex_types={ + "decay": DecayMegacomplex, + "spectral": SpectralMegacomplex, + }, + ) + + def plot_overview(res, title=None): """very simple plot helper function derived from pyglotaran_extras""" import matplotlib.pyplot as plt @@ -51,17 +66,17 @@ def plot_overview(res, title=None): plt.show(block=False) -def notest_equal_area_penalties(debug=False): +def test_equal_area_penalties(debug=False): # %% optim_spec = OptimizationSpec(nnls=True, max_nfev=999) noise_spec = NoiseSpec(active=True, seed=1, std_dev=1e-8) - wavelengths = np.arange(650, 670, 2) - time_p1 = np.linspace(-1, 2, 50, endpoint=False) - time_p2 = np.linspace(2, 10, 30, endpoint=False) - time_p3 = np.geomspace(10, 50, num=20) - times = np.concatenate([time_p1, time_p2, time_p3]) + wavelengths = xr.DataArray(np.arange(650, 670, 2)) + time_p1 = xr.DataArray(np.linspace(-1, 2, 50, endpoint=False)) + time_p2 = xr.DataArray(np.linspace(2, 10, 30, endpoint=False)) + time_p3 = xr.DataArray(np.geomspace(10, 50, num=20)) + times = xr.DataArray(np.concatenate([time_p1, time_p2, time_p3])) irf_loc = float(times[20]) irf_width = float((times[1] - times[0]) * 10) @@ -88,7 +103,7 @@ def notest_equal_area_penalties(debug=False): }, }, "megacomplex": { - "mc1": {"k_matrix": ["k1"]}, + "mc1": {"type": "decay", "k_matrix": ["k1"]}, }, "k_matrix": { "k1": { @@ -113,13 +128,13 @@ def notest_equal_area_penalties(debug=False): shape = { "shape": { "sh1": { - "type": "gaussian", + "type": "skewed-gaussian", "amplitude": "shapes.amps.1", "location": "shapes.locs.1", "width": "shapes.width.1", }, "sh2": { - "type": "gaussian", + 
"type": "skewed-gaussian", "amplitude": "shapes.amps.2", "location": "shapes.locs.2", "width": "shapes.width.2", @@ -127,15 +142,16 @@ def notest_equal_area_penalties(debug=False): } } - dataset_shape = { + spectral_megacomplex = { + "type": "spectral", "shape": { "s1": "sh1", "s2": "sh2", - } + }, } equ_area = { - "equal_area_penalties": [ + "clp_area_penalties": [ { "source": "s1", "target": "s2", @@ -146,7 +162,7 @@ def notest_equal_area_penalties(debug=False): }, ], } - mspec = ModelSpec(base, shape, dataset_shape, equ_area) + mspec = ModelSpec(base, shape, spectral_megacomplex, equ_area) rela = 1.0 # relation between areas irf = dataset_spec.irf @@ -170,14 +186,15 @@ def notest_equal_area_penalties(debug=False): # derivates: mspec_sim = dict(deepcopy(mspec.base), **mspec.shape) - mspec_sim["dataset"]["dataset1"].update(mspec.dataset_shape) + mspec_sim["megacomplex"]["mc2"] = mspec.spectral_megacomplex + mspec_sim["dataset"]["dataset1"]["global_megacomplex"] = ["mc2"] mspec_fit_wp = dict(deepcopy(mspec.base), **mspec.equ_area) mspec_fit_np = dict(deepcopy(mspec.base)) - model_sim = KineticSpectrumModel.from_dict(mspec_sim) - model_wp = KineticSpectrumModel.from_dict(mspec_fit_wp) - model_np = KineticSpectrumModel.from_dict(mspec_fit_np) + model_sim = SpectralDecayModel.from_dict(mspec_sim) + model_wp = SpectralDecayModel.from_dict(mspec_fit_wp) + model_np = SpectralDecayModel.from_dict(mspec_fit_np) print(model_np) # %% Parameter specification (pspec) @@ -207,7 +224,7 @@ def notest_equal_area_penalties(debug=False): model_sim, "dataset1", param_sim, - axes={"time": times, "spectral": wavelengths}, + coordinates={"time": times, "spectral": wavelengths}, noise=noise_spec.active, noise_std_dev=noise_spec.std_dev, noise_seed=noise_spec.seed, @@ -255,6 +272,7 @@ def notest_equal_area_penalties(debug=False): print(result_wp.data["dataset1"]) area1_np = np.sum(result_np.data["dataset1"].species_associated_spectra.sel(species="s1")) area2_np = np.sum(result_np.data["dataset1"].species_associated_spectra.sel(species="s2")) + print("area_np", area1_np, area2_np) assert not np.isclose(area1_np, area2_np) area1_wp = np.sum(result_wp.data["dataset1"].species_associated_spectra.sel(species="s1")) @@ -264,12 +282,10 @@ def notest_equal_area_penalties(debug=False): input_ratio = result_wp.optimized_parameters.get("i.1") / result_wp.optimized_parameters.get( "i.2" ) + print("input", input_ratio) assert np.isclose(input_ratio, 1.5038858115) -# if __name__ == "__main__": -# test__get_idx_from_interval( -# type_factory=list, interval=(500, 600), axis=range(400, 800, 100), expected=(1, 2) -# ) -# test_equal_area_penalties(debug=False) -# test_equal_area_penalties(debug=True) +if __name__ == "__main__": + test_equal_area_penalties(debug=True) + test_equal_area_penalties(debug=False) diff --git a/requirements_dev.txt b/requirements_dev.txt index 8721ca5d3..1a2b4877c 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -14,6 +14,7 @@ xarray==0.19.0 netCDF4==1.5.7 setuptools==41.2 sdtfile==2021.3.21 +typing_inspect==0.7.1 tabulate==0.8.9 # documentation dependencies diff --git a/setup.cfg b/setup.cfg index d90957a3c..3175b7a64 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,7 @@ install_requires = sdtfile>=2020.8.3 setuptools>=41.2 tabulate>=0.8.8 + typing_inspect>=0.7.1 xarray>=0.16.2 python_requires = >=3.8, <3.10 setup_requires = @@ -57,10 +58,11 @@ glotaran.plugins.data_io = ascii = glotaran.builtin.io.ascii.wavelength_time_explicit_file sdt = glotaran.builtin.io.sdt.sdt_file_reader nc = 
glotaran.builtin.io.netCDF.netCDF
-glotaran.plugins.model =
-    kinetic_image = glotaran.builtin.models.kinetic_image
-    kinetic_spectrum = glotaran.builtin.models.kinetic_spectrum
-    spectral = glotaran.builtin.models.spectral
+glotaran.plugins.megacomplexes =
+    baseline = glotaran.builtin.megacomplexes.baseline
+    coherent_artifact = glotaran.builtin.megacomplexes.coherent_artifact
+    decay = glotaran.builtin.megacomplexes.decay
+    spectral = glotaran.builtin.megacomplexes.spectral
 glotaran.plugins.project_io =
     yml = glotaran.builtin.io.yml.yml
     csv = glotaran.builtin.io.csv.csv

From ffec6c7168af85de9ced2faa16832cb4ed4926e7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?J=C3=B6rn=20Wei=C3=9Fenborn?=
Date: Thu, 15 Jul 2021 09:58:09 +0200
Subject: [PATCH 05/29] =?UTF-8?q?=F0=9F=A9=B9=20Fix=20Performance=20Regres?=
 =?UTF-8?q?sions=20(#740)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Added benchmark for Problem class
* removed print
* ♻️ 'Refactored by Sourcery'
* 🧹 Moved glotaran/analysis/test/test_relations.py to benchmark/pytest/analysis/test_problem.py

Use 'pytest benchmark/pytest/' to run the benchmarks

* Numerous performance tweaks
* Don't weight data with ones if no weight supplied
* Fix duplicate call to create result_dataset
* Removed dead code
* Switched back to pure numpy in problem
* Fixed example
* Cleanup
* Update glotaran/analysis/problem_ungrouped.py

Co-authored-by: Sebastian Weigand

* Update glotaran/analysis/problem_ungrouped.py

Co-authored-by: Sebastian Weigand

* Update glotaran/analysis/util.py

Co-authored-by: Sebastian Weigand

Co-authored-by: Sourcery AI <>
Co-authored-by: s-weigand
---
 .gitignore                                    |   1 +
 benchmark/pytest/analysis/test_problem.py     | 157 +++++++++++
 glotaran/analysis/problem.py                  |  75 +-----
 glotaran/analysis/problem_grouped.py          | 244 ++++++++++--------
 glotaran/analysis/problem_ungrouped.py        | 191 ++++++--------
 glotaran/analysis/simulation.py               |  13 +-
 glotaran/analysis/test/models.py              |  11 +-
 glotaran/analysis/test/test_constraints.py    |  13 +-
 glotaran/analysis/test/test_optimization.py   |   2 -
 glotaran/analysis/test/test_penalties.py      |   4 +-
 glotaran/analysis/test/test_problem.py        |  33 +--
 glotaran/analysis/test/test_relations.py      |  15 +-
 glotaran/analysis/util.py                     | 186 +++++++------
 .../baseline/baseline_megacomplex.py          |   7 +-
 .../test/test_baseline_megacomplex.py         |  13 +-
 .../coherent_artifact_megacomplex.py          |   9 +-
 .../test/test_coherent_artifact.py            |  10 +-
 .../megacomplexes/decay/decay_megacomplex.py  |   7 +-
 .../decay/test/test_decay_megacomplex.py      |  16 +-
 .../decay/test/test_spectral_irf.py           |   9 +-
 .../spectral/spectral_megacomplex.py          |   8 +-
 .../spectral/test/test_spectral_model.py      |  18 +-
 .../builtin/models/spectral/test/__init__.py  |   0
 glotaran/examples/sequential.py               |   5 +-
 glotaran/model/clp_penalties.py               |   8 +-
 glotaran/model/dataset_model.py               |  29 ++-
 glotaran/model/test/test_model.py             |   2 +-
 glotaran/test/test_spectral_decay.py          |  99 +------
 glotaran/test/test_spectral_penalties.py      |  11 +-
 29 files changed, 622 insertions(+), 574 deletions(-)
 create mode 100644 benchmark/pytest/analysis/test_problem.py
 delete mode 100644 glotaran/builtin/models/spectral/test/__init__.py

diff --git a/.gitignore b/.gitignore
index 7cbcbfb26..6329ca758 100644
--- a/.gitignore
+++ b/.gitignore
@@ -170,3 +170,4 @@ _summary.ps

 # benchmark results
 benchmark/.asv
+.benchmarks/
diff --git a/benchmark/pytest/analysis/test_problem.py b/benchmark/pytest/analysis/test_problem.py
new file mode 100644
index 000000000..265060feb
--- /dev/null
+++
b/benchmark/pytest/analysis/test_problem.py @@ -0,0 +1,157 @@ +import numpy as np +import pytest +import xarray as xr + +from glotaran.analysis.problem_grouped import GroupedProblem +from glotaran.analysis.problem_ungrouped import UngroupedProblem +from glotaran.model import Megacomplex +from glotaran.model import Model +from glotaran.model import megacomplex +from glotaran.parameter import ParameterGroup +from glotaran.project import Scheme + +TEST_AXIS_MODEL_SIZE = 100 +TEST_AXIS_MODEL = xr.DataArray(np.arange(0, TEST_AXIS_MODEL_SIZE)) +TEST_AXIS_GLOBAL_SIZE = 100 +TEST_AXIS_GLOBAL = xr.DataArray(np.arange(0, TEST_AXIS_GLOBAL_SIZE)) +TEST_CLP_SIZE = 20 +TEST_CLP_LABELS = [f"{i+1}" for i in range(TEST_CLP_SIZE)] +TEST_MATRIX = np.ones((TEST_AXIS_MODEL_SIZE, TEST_CLP_SIZE)) +# TEST_MATRIX = xr.DataArray( +# np.ones((TEST_AXIS_MODEL_SIZE, TEST_CLP_SIZE)), +# coords=(("test", TEST_AXIS_MODEL.data), ("clp_label", TEST_CLP_LABELS)), +# ) +TEST_DATA = xr.DataArray( + np.ones((TEST_AXIS_GLOBAL_SIZE, TEST_AXIS_MODEL_SIZE)), + coords=(("global", TEST_AXIS_GLOBAL.data), ("test", TEST_AXIS_MODEL.data)), +) +TEST_PARAMETER = ParameterGroup.from_list([]) + + +@megacomplex(dimension="test", properties={"is_index_dependent": bool}) +class BenchmarkMegacomplex(Megacomplex): + def calculate_matrix(self, dataset_model, indices, **kwargs): + return TEST_CLP_LABELS, TEST_MATRIX + + def index_dependent(self, dataset_model): + return self.is_index_dependent + + def finalize_data(self, dataset_model, data): + pass + + +def setup_model(index_dependent): + model_dict = { + "megacomplex": {"m1": {"is_index_dependent": index_dependent}}, + "dataset": { + "dataset1": {"megacomplex": ["m1"]}, + "dataset2": {"megacomplex": ["m1"]}, + "dataset3": {"megacomplex": ["m1"]}, + }, + } + return Model.from_dict( + model_dict, + megacomplex_types={"benchmark": BenchmarkMegacomplex}, + default_megacomplex_type="benchmark", + ) + + +def setup_scheme(model): + return Scheme( + model=model, + parameters=TEST_PARAMETER, + data={ + "dataset1": TEST_DATA, + "dataset2": TEST_DATA, + "dataset3": TEST_DATA, + }, + ) + + +def setup_problem(scheme, grouped): + return GroupedProblem(scheme) if grouped else UngroupedProblem(scheme) + + +def test_benchmark_bag_creation(benchmark): + + model = setup_model(False) + assert model.valid() + + scheme = setup_scheme(model) + problem = setup_problem(scheme, True) + + benchmark(problem.init_bag) + + +@pytest.mark.parametrize("grouped", [True, False]) +@pytest.mark.parametrize("index_dependent", [True, False]) +def test_benchmark_calculate_matrix(benchmark, grouped, index_dependent): + + model = setup_model(index_dependent) + assert model.valid() + + scheme = setup_scheme(model) + problem = setup_problem(scheme, grouped) + + if grouped: + problem.init_bag() + + benchmark(problem.calculate_matrices) + + +@pytest.mark.parametrize("grouped", [True, False]) +@pytest.mark.parametrize("index_dependent", [True, False]) +def test_benchmark_calculate_residual(benchmark, grouped, index_dependent): + + model = setup_model(index_dependent) + assert model.valid() + + scheme = setup_scheme(model) + problem = setup_problem(scheme, grouped) + + if grouped: + problem.init_bag() + problem.calculate_matrices() + + benchmark(problem.calculate_residual) + + +@pytest.mark.parametrize("grouped", [True, False]) +@pytest.mark.parametrize("index_dependent", [True, False]) +def test_benchmark_calculate_result_data(benchmark, grouped, index_dependent): + + model = setup_model(index_dependent) + assert model.valid() + + scheme = 
setup_scheme(model) + problem = setup_problem(scheme, grouped) + + if grouped: + problem.init_bag() + problem.calculate_matrices() + problem.calculate_residual() + + benchmark(problem.create_result_data) + + +# @pytest.mark.skip(reason="To time consuming atm.") +@pytest.mark.parametrize("grouped", [True, False]) +@pytest.mark.parametrize("index_dependent", [True, False]) +def test_benchmark_optimize_20_runs(benchmark, grouped, index_dependent): + + model = setup_model(index_dependent) + assert model.valid() + + scheme = setup_scheme(model) + problem = setup_problem(scheme, grouped) + + @benchmark + def run(): + if grouped: + problem.init_bag() + + for _ in range(20): + problem.reset() + problem.full_penalty + + problem.create_result_data() diff --git a/glotaran/analysis/problem.py b/glotaran/analysis/problem.py index 320dfc9ef..400b179e3 100644 --- a/glotaran/analysis/problem.py +++ b/glotaran/analysis/problem.py @@ -2,7 +2,6 @@ import warnings from typing import TYPE_CHECKING -from typing import Deque from typing import Dict from typing import NamedTuple from typing import TypeVar @@ -55,7 +54,6 @@ class ProblemGroup(NamedTuple): UngroupedBag = Dict[str, UngroupedProblemDescriptor] -GroupedBag = Deque[ProblemGroup] XrDataContainer = TypeVar("XrDataContainer", xr.DataArray, xr.Dataset) @@ -149,12 +147,6 @@ def parameter_history(self) -> list[ParameterGroup]: def dataset_models(self) -> dict[str, DatasetModel]: return self._dataset_models - @property - def bag(self) -> UngroupedBag | GroupedBag: - if not self._bag: - self.init_bag() - return self._bag - @property def matrices( self, @@ -326,19 +318,15 @@ def _add_weight(self, label, dataset): ) dataset.weight[idx] *= weight.value - def calculate_matrices(self): - raise NotImplementedError - - def calculate_residual(self): - raise NotImplementedError - def create_result_data( self, copy: bool = True, history_index: int | None = None ) -> dict[str, xr.Dataset]: if history_index is not None and history_index != -1: self.parameters = self.parameter_history[history_index] - result_data = {label: self.create_result_dataset(label, copy=copy) for label in self.data} + + self.prepare_result_creation() + result_data = {} for label, dataset_model in self.dataset_models.items(): result_data[label] = self.create_result_dataset(label, copy=copy) dataset_model.finalize_data(result_data[label]) @@ -395,54 +383,6 @@ def _create_svd(self, name: str, dataset: xr.Dataset, lsv_dim: str, rsv_dim: str dataset, name=name, lsv_dim=lsv_dim, rsv_dim=rsv_dim, data_array=data_array ) - def init_bag(self): - """Initializes a problem bag.""" - raise NotImplementedError - - def calculate_index_dependent_matrices( - self, - ) -> tuple[ - dict[str, list[list[str]]], - dict[str, list[np.ndarray]], - dict[str, list[str]], - dict[str, list[np.ndarray]], - ]: - """Calculates the index dependent model matrices.""" - raise NotImplementedError - - def calculate_index_independent_matrices( - self, - ) -> tuple[ - dict[str, list[str]], - dict[str, np.ndarray], - dict[str, list[str]], - dict[str, np.ndarray], - ]: - """Calculates the index independent model matrices.""" - raise NotImplementedError - - def calculate_index_dependent_residual( - self, - ) -> tuple[ - dict[str, list[np.ndarray]], - dict[str, list[np.ndarray]], - dict[str, list[np.ndarray]], - dict[str, list[np.ndarray]], - ]: - """Calculates the index dependent residuals.""" - raise NotImplementedError - - def calculate_index_independent_residual( - self, - ) -> tuple[ - dict[str, list[np.ndarray]], - dict[str, 
list[np.ndarray]], - dict[str, list[np.ndarray]], - dict[str, list[np.ndarray]], - ]: - """Calculates the index independent residuals.""" - raise NotImplementedError - def create_index_dependent_result_dataset(self, label: str, dataset: xr.Dataset) -> xr.Dataset: """Creates a result datasets for index dependent matrices.""" raise NotImplementedError @@ -452,3 +392,12 @@ def create_index_independent_result_dataset( ) -> xr.Dataset: """Creates a result datasets for index independent matrices.""" raise NotImplementedError + + def calculate_matrices(self): + raise NotImplementedError + + def calculate_residual(self): + raise NotImplementedError + + def prepare_result_creation(self): + pass diff --git a/glotaran/analysis/problem_grouped.py b/glotaran/analysis/problem_grouped.py index 7a93717b4..5d0b0041e 100644 --- a/glotaran/analysis/problem_grouped.py +++ b/glotaran/analysis/problem_grouped.py @@ -11,6 +11,8 @@ from glotaran.analysis.problem import ParameterError from glotaran.analysis.problem import Problem from glotaran.analysis.problem import ProblemGroup +from glotaran.analysis.util import CalculatedMatrix +from glotaran.analysis.util import apply_weight from glotaran.analysis.util import calculate_clp_penalties from glotaran.analysis.util import calculate_matrix from glotaran.analysis.util import find_closest_index @@ -20,6 +22,8 @@ from glotaran.model import DatasetModel from glotaran.project import Scheme +Bag = Deque[ProblemGroup] + class GroupedProblem(Problem): """Represents a problem where the data is grouped.""" @@ -49,30 +53,37 @@ def __init__(self, scheme: Scheme): self._model_dimension = model_dimensions.pop() self._group_clp_labels = None self._groups = None + self._has_weights = any("weight" in d for d in self._data.values()) + + @property + def bag(self) -> Bag: + if not self._bag: + self.init_bag() + return self._bag def init_bag(self): """Initializes a grouped problem bag.""" + self._bag = None datasets = None - for label in self._model.dataset: - dataset = self._data[label] - if "weight" in dataset: - weight = dataset.weight - data = dataset.data * weight - dataset["weighted_data"] = data - else: - weight = xr.DataArray(np.ones_like(dataset.data), coords=dataset.data.coords) - data = dataset.data - global_axis = dataset.coords[self._global_dimension].values - model_axis = dataset.coords[self._model_dimension].values - has_scaling = self._model.dataset[label].scale is not None + for label, dataset_model in self.dataset_models.items(): + + data = dataset_model.get_data() + weight = dataset_model.get_weight() + if weight is None and self._has_weights: + weight = np.ones_like(data) + + global_axis = dataset_model.get_global_axis() + model_axis = dataset_model.get_model_axis() + has_scaling = dataset_model.scale is not None + if self._bag is None: self._bag = collections.deque( ProblemGroup( - data=data.isel({self._global_dimension: i}).values, - weight=weight.isel({self._global_dimension: i}).values, + data=data[:, i], + weight=weight[:, i] if weight is not None else None, has_scaling=has_scaling, group=label, - data_sizes=[data.isel({self._global_dimension: i}).values.size], + data_sizes=[model_axis.size], descriptor=[ GroupedProblemDescriptor( label, @@ -111,7 +122,7 @@ def _append_to_grouped_bag( for i, j in enumerate(i1): datasets[j].append(label) - data_stripe = data.isel({self._global_dimension: i2[i]}).values + data_stripe = data[:, i2[i]] self._bag[j] = ProblemGroup( data=np.concatenate( [ @@ -119,12 +130,9 @@ def _append_to_grouped_bag( data_stripe, ] ), - 
weight=np.concatenate( - [ - self._bag[j].weight, - weight.isel({self._global_dimension: i2[i]}).values, - ] - ), + weight=np.concatenate([self._bag[j].weight, weight[:, i2[i]]]) + if weight is not None + else None, has_scaling=has_scaling or self._bag[j].has_scaling, group=self._bag[j].group + label, data_sizes=self._bag[j].data_sizes + [data_stripe.size], @@ -147,10 +155,10 @@ def _append_to_grouped_bag( begin_overlap = i2[0] if len(i2) != 0 else 0 end_overlap = i2[-1] + 1 if len(i2) != 0 else 0 for i in itertools.chain(range(begin_overlap), range(end_overlap, len(global_axis))): - data_stripe = data.isel({self._global_dimension: i}).values + data_stripe = data[:, i] problem = ProblemGroup( data=data_stripe, - weight=weight.isel({self._global_dimension: i}).values, + weight=weight[:, i] if weight is not None else None, has_scaling=has_scaling, group=label, data_sizes=[data_stripe.size], @@ -192,12 +200,12 @@ def calculate_matrices(self): def calculate_index_dependent_matrices( self, - ) -> tuple[dict[str, list[np.ndarray]], list[np.ndarray],]: + ) -> tuple[dict[str, list[CalculatedMatrix]], list[CalculatedMatrix],]: """Calculates the index dependent model matrices.""" def calculate_group( group: ProblemGroup, descriptors: dict[str, DatasetModel] - ) -> tuple[list[xr.DataArray], xr.DataArray, xr.DataArray]: + ) -> tuple[list[CalculatedMatrix], list[str], CalculatedMatrix]: matrices = [ calculate_matrix( descriptors[problem.label], @@ -207,14 +215,14 @@ def calculate_group( ] global_index = group.descriptor[0].indices[self._global_dimension] global_index = group.descriptor[0].axis[self._global_dimension][global_index] - combined_matrix = xr.concat(matrices, dim=self._model_dimension).fillna(0) - group_clp_labels = combined_matrix.coords["clp_label"] + combined_matrix = combine_matrices(matrices) + group_clp_labels = combined_matrix.clp_labels reduced_matrix = reduce_matrix( - combined_matrix, self.model, self.parameters, self._model_dimension, global_index + combined_matrix, self.model, self.parameters, global_index ) return matrices, group_clp_labels, reduced_matrix - results = list(map(lambda group: calculate_group(group, self.dataset_models), self._bag)) + results = list(map(lambda group: calculate_group(group, self.dataset_models), self.bag)) matrices = list(map(lambda result: result[0], results)) @@ -231,7 +239,7 @@ def calculate_group( def calculate_index_independent_matrices( self, - ) -> tuple[dict[str, xr.DataArray], dict[str, xr.DataArray],]: + ) -> tuple[dict[str, CalculatedMatrix], dict[str, CalculatedMatrix],]: """Calculates the index independent model matrices.""" self._matrices = {} self._group_clp_labels = {} @@ -242,24 +250,22 @@ def calculate_index_independent_matrices( dataset_model, {}, ) - self._group_clp_labels[label] = self._matrices[label].coords["clp_label"] + self._group_clp_labels[label] = self._matrices[label].clp_labels self._reduced_matrices[label] = reduce_matrix( self._matrices[label], self.model, self.parameters, - self._model_dimension, None, ) for group_label, group in self.groups.items(): if group_label not in self._matrices: - self._reduced_matrices[group_label] = xr.concat( - [self._reduced_matrices[label] for label in group], dim=self._model_dimension - ).fillna(0) - group_clp_labels = xr.align( - *(self._matrices[label].coords["clp_label"] for label in group), join="outer" + self._reduced_matrices[group_label] = combine_matrices( + [self._reduced_matrices[label] for label in group] + ) + self._group_clp_labels[group_label] = list( + 
set(itertools.chain(*(self._matrices[label].clp_labels for label in group))) ) - self._group_clp_labels[group_label] = group_clp_labels[0].coords["clp_label"] return self._matrices, self._reduced_matrices @@ -278,18 +284,13 @@ def calculate_residual(self): else list(map(self._index_independent_residual, self.bag, self._full_axis)) ) - clps = xr.concat(list(map(lambda result: result[0], results)), dim=self._global_dimension) - clps.coords[self._global_dimension] = self._full_axis - reduced_clps = xr.concat( - list(map(lambda result: result[1], results)), dim=self._global_dimension - ) - reduced_clps.coords[self._global_dimension] = self._full_axis - self._ungroup_clps(clps, reduced_clps) + self._clp_labels = list(map(lambda result: result[0], results)) + self._grouped_clps = list(map(lambda result: result[1], results)) self._weighted_residuals = list(map(lambda result: result[2], results)) self._residuals = list(map(lambda result: result[3], results)) self._additional_penalty = calculate_clp_penalties( - self.model, self.parameters, clps, self._global_dimension + self.model, self.parameters, self._clp_labels, self._grouped_clps, self._full_axis ) return self._reduced_clps, self._clps, self._weighted_residuals, self._residuals @@ -297,14 +298,15 @@ def calculate_residual(self): def _index_dependent_residual( self, problem: ProblemGroup, - matrix: np.ndarray, - group_clp_labels: str, + matrix: CalculatedMatrix, + clp_labels: str, index: any, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - matrix = matrix.copy() - for i in range(matrix.shape[1]): - matrix[:, i] *= problem.weight + reduced_clp_labels = matrix.clp_labels + matrix = matrix.matrix + if problem.weight is not None: + apply_weight(matrix, problem.weight) data = problem.data if problem.has_scaling: for i, descriptor in enumerate(problem.descriptor): @@ -314,23 +316,26 @@ def _index_dependent_residual( end = start + problem.data_sizes[i] matrix[start:end, :] *= self.dataset_models[label].scale - reduced_clps, residual = self._residual_function(matrix.values, data) - reduced_clps = xr.DataArray( - reduced_clps, dims=["clp_label"], coords={"clp_label": matrix.coords["clp_label"]} - ) + reduced_clps, weighted_residual = self._residual_function(matrix, data) clps = retrieve_clps( self.model, self.parameters, - group_clp_labels, + clp_labels, + reduced_clp_labels, reduced_clps, index, ) - return clps, reduced_clps, residual, residual / problem.weight + residual = ( + weighted_residual / problem.weight if problem.weight is not None else weighted_residual + ) + return clp_labels, clps, weighted_residual, residual def _index_independent_residual(self, problem: ProblemGroup, index: any): - matrix = self.reduced_matrices[problem.group].copy() - for i in range(matrix.shape[1]): - matrix[:, i] *= problem.weight + matrix = self.reduced_matrices[problem.group] + reduced_clp_labels = matrix.clp_labels + matrix = matrix.matrix.copy() + if problem.weight is not None: + apply_weight(matrix, problem.weight) data = problem.data if problem.has_scaling: for i, descriptor in enumerate(problem.descriptor): @@ -339,65 +344,47 @@ def _index_independent_residual(self, problem: ProblemGroup, index: any): start = sum(problem.data_sizes[0:i]) end = start + problem.data_sizes[i] matrix[start:end, :] *= self.dataset_models[label].scale - reduced_clps, residual = self._residual_function(matrix.values, data) - reduced_clps = xr.DataArray( - reduced_clps, dims=["clp_label"], coords={"clp_label": matrix.coords["clp_label"]} - ) + reduced_clps, weighted_residual = 
self._residual_function(matrix, data) + clp_labels = self._group_clp_labels[problem.group] clps = retrieve_clps( self.model, self.parameters, - self._group_clp_labels[problem.group], + clp_labels, + reduced_clp_labels, reduced_clps, index, ) - return clps, reduced_clps, residual, residual / problem.weight + residual = ( + weighted_residual / problem.weight if problem.weight is not None else weighted_residual + ) + return clp_labels, clps, weighted_residual, residual - def _ungroup_clps(self, clps: xr.DataArray, reduced_clps: xr.DataArray): - self._reduced_clps = {} + def prepare_result_creation(self): + if self._residuals is None: + self.calculate_residual() + full_clp_labels = self._clp_labels + full_clps = self._grouped_clps self._clps = {} - for label in self.matrices: - clp_labels = ( - [m.coords["clp_label"] for m in self.matrices[label]] - if self._index_dependent - else self.matrices[label].coords["clp_label"] - ) + for label, matrix in self.matrices.items(): + # TODO deal with different clps at indices + clp_labels = matrix[0].clp_labels if self._index_dependent else matrix.clp_labels # find offset in the full axis - offset = find_closest_index( - self.data[label].coords[self._global_dimension][0].values, self._full_axis + global_axis = self.dataset_models[label].get_global_axis() + offset = find_closest_index(global_axis[0], self._full_axis) + + clps = [] + for i in range(global_axis.size): + full_index_clp_labels = full_clp_labels[i + offset] + index_clps = full_clps[i + offset] + mask = [full_index_clp_labels.index(clp_label) for clp_label in clp_labels] + clps.append(index_clps[mask]) + + self._clps[label] = xr.DataArray( + clps, + coords=((self._global_dimension, global_axis), ("clp_label", clp_labels)), ) - self._reduced_clps[label] = [] - self._clps[label] = [] - - for i in range(self.data[label].coords[self._global_dimension].size): - - index_clp_labels = clp_labels[i] if self._index_dependent else clp_labels - index_clps = clps[i + offset] - index_clps = index_clps.sel({"clp_label": index_clp_labels}) - self._clps[label].append(index_clps) - - index_reduced_clps = reduced_clps[i + offset] - index_reduced_clp_labels, _ = xr.align( - index_clp_labels, index_reduced_clps.coords["clp_label"] - ) - index_reduced_clps = index_reduced_clps.sel( - {"clp_label": index_reduced_clp_labels} - ) - self._reduced_clps[label].append(index_reduced_clps) - - self._reduced_clps[label] = xr.concat( - self.reduced_clps[label], dim=self._global_dimension - ) - self._reduced_clps[label].coords[self._global_dimension] = self.data[label].coords[ - self._global_dimension - ] - - self._clps[label] = xr.concat(self._clps[label], dim=self._global_dimension) - self._clps[label].coords[self._global_dimension] = self.data[label].coords[ - self._global_dimension - ] - def create_index_dependent_result_dataset(self, label: str, dataset: xr.Dataset) -> xr.Dataset: """Creates a result datasets for index dependent matrices.""" @@ -421,7 +408,7 @@ def create_index_dependent_result_dataset(self, label: str, dataset: xr.Dataset) (self._model_dimension), ("clp_label"), ), - self.matrices[label], + np.asarray([m.matrix for m in self.matrices[label]]), ) dataset["clp"] = self.clps[label] @@ -432,7 +419,13 @@ def create_index_independent_result_dataset( ) -> xr.Dataset: """Creates a result datasets for index independent matrices.""" - dataset["matrix"] = self.matrices[label] + dataset["matrix"] = ( + ( + (self._model_dimension), + ("clp_label"), + ), + self.matrices[label].matrix, + ) dataset["clp"] = 
self.clps[label] for index, grouped_problem in enumerate(self.bag): @@ -496,3 +489,34 @@ def full_penalty(self) -> np.ndarray: else np.concatenate(residuals) ) return self._full_penalty + + +def combine_matrices(matrices: list[CalculatedMatrix]) -> CalculatedMatrix: + masks = [] + full_clp_labels = None + sizes = [] + dim1 = 0 + for matrix in matrices: + clp_labels = matrix.clp_labels + model_axis_size = matrix.matrix.shape[0] + sizes.append(model_axis_size) + dim1 += model_axis_size + if full_clp_labels is None: + full_clp_labels = clp_labels.copy() + masks.append([i for i, _ in enumerate(clp_labels)]) + else: + mask = [] + for c in clp_labels: + if c not in full_clp_labels: + full_clp_labels.append(c) + mask.append(full_clp_labels.index(c)) + masks.append(mask) + dim2 = len(full_clp_labels) + full_matrix = np.zeros((dim1, dim2), dtype=np.float64) + start = 0 + for i, m in enumerate(matrices): + end = start + sizes[i] + full_matrix[start:end, masks[i]] = m.matrix + start = end + + return CalculatedMatrix(full_clp_labels, full_matrix) diff --git a/glotaran/analysis/problem_ungrouped.py b/glotaran/analysis/problem_ungrouped.py index e5f030ef5..1acdf5cd6 100644 --- a/glotaran/analysis/problem_ungrouped.py +++ b/glotaran/analysis/problem_ungrouped.py @@ -5,7 +5,8 @@ from glotaran.analysis.problem import ParameterError from glotaran.analysis.problem import Problem -from glotaran.analysis.problem import UngroupedProblemDescriptor +from glotaran.analysis.util import CalculatedMatrix +from glotaran.analysis.util import apply_weight from glotaran.analysis.util import calculate_clp_penalties from glotaran.analysis.util import calculate_matrix from glotaran.analysis.util import reduce_matrix @@ -16,29 +17,11 @@ class UngroupedProblem(Problem): """Represents a problem where the data is not grouped.""" - def init_bag(self): - """Initializes an ungrouped problem bag.""" - self._bag = {} - for label, dataset_model in self.dataset_models.items(): - dataset = self._data[label] - data = dataset.data - weight = dataset.weight if "weight" in dataset else None - if weight is not None: - data = data * weight - dataset["weighted_data"] = data - self._bag[label] = UngroupedProblemDescriptor( - dataset_model, - data, - dataset.coords[dataset_model.get_model_dimension()].values, - dataset.coords[dataset_model.get_global_dimension()].values, - weight, - ) - def calculate_matrices( self, ) -> tuple[ - dict[str, list[xr.DataArray] | xr.DataArray], - dict[str, list[xr.DataArray] | xr.DataArray], + dict[str, CalculatedMatrix | list[CalculatedMatrix]], + dict[str, CalculatedMatrix | list[CalculatedMatrix]], ]: """Calculates the model matrices.""" if self._parameters is None: @@ -47,44 +30,32 @@ def calculate_matrices( self._matrices = {} self._reduced_matrices = {} - for label, problem in self.bag.items(): - dataset_model = self.dataset_models[label] + for label, dataset_model in self.dataset_models.items(): if dataset_model.index_dependent(): - self._calculate_index_dependent_matrix(label, problem, dataset_model) + self._calculate_index_dependent_matrix(label, dataset_model) else: - self._calculate_index_independent_matrix(label, problem, dataset_model) + self._calculate_index_independent_matrix(label, dataset_model) return self._matrices, self._reduced_matrices - def _calculate_index_dependent_matrix( - self, label: str, problem: UngroupedProblemDescriptor, dataset_model: DatasetModel - ): + def _calculate_index_dependent_matrix(self, label: str, dataset_model: DatasetModel): self._matrices[label] = [] 
self._reduced_matrices[label] = [] - for i, index in enumerate(problem.global_axis): + for i, index in enumerate(dataset_model.get_global_axis()): matrix = calculate_matrix( dataset_model, {dataset_model.get_global_dimension(): i}, ) self._matrices[label].append(matrix) - reduced_matrix = reduce_matrix( - matrix, self.model, self.parameters, dataset_model.get_model_dimension(), index - ) + reduced_matrix = reduce_matrix(matrix, self.model, self.parameters, index) self._reduced_matrices[label].append(reduced_matrix) - def _calculate_index_independent_matrix( - self, label: str, problem: UngroupedProblemDescriptor, dataset_model: DatasetModel - ): + def _calculate_index_independent_matrix(self, label: str, dataset_model: DatasetModel): - matrix = calculate_matrix( - dataset_model, - {}, - ) + matrix = calculate_matrix(dataset_model, {}) self._matrices[label] = matrix - reduced_matrix = reduce_matrix( - matrix, self.model, self.parameters, dataset_model.get_model_dimension(), None - ) + reduced_matrix = reduce_matrix(matrix, self.model, self.parameters, None) self._reduced_matrices[label] = reduced_matrix def calculate_residual( @@ -103,81 +74,89 @@ def calculate_residual( self._residuals = {} self._additional_penalty = [] - for label, problem in self.bag.items(): - self._calculate_residual_for_problem(label, problem) + for label, dataset_model in self._dataset_models.items(): + self._calculate_residual(label, dataset_model) self._additional_penalty = ( np.concatenate(self._additional_penalty) if len(self._additional_penalty) != 0 else [] ) return self._reduced_clps, self._clps, self._weighted_residuals, self._residuals - def _calculate_residual_for_problem(self, label: str, problem: UngroupedProblemDescriptor): + def _calculate_residual(self, label: str, dataset_model: DatasetModel): self._reduced_clps[label] = [] self._clps[label] = [] self._weighted_residuals[label] = [] self._residuals[label] = [] - data = problem.data - dataset_model = self.dataset_models[label] - model_dimension = dataset_model.get_model_dimension() - global_dimension = dataset_model.get_global_dimension() - - for i, index in enumerate(problem.global_axis): - clp_labels = ( - self.matrices[label][i].coords["clp_label"] - if dataset_model.index_dependent() - else self.matrices[label].coords["clp_label"] - ) - reduced_matrix = ( + + data = dataset_model.get_data() + global_axis = dataset_model.get_global_axis() + + for i, index in enumerate(global_axis): + reduced_clp_labels, reduced_matrix = ( self.reduced_matrices[label][i] if dataset_model.index_dependent() else self.reduced_matrices[label] ) - if problem.dataset.scale is not None: - reduced_matrix *= self.dataset_models[label].scale + if not dataset_model.index_dependent(): + reduced_matrix = reduced_matrix.copy() - if problem.weight is not None: - for j in range(reduced_matrix.shape[1]): - reduced_matrix[:, j] *= problem.weight.isel({global_dimension: i}).values + if dataset_model.scale is not None: + reduced_matrix *= dataset_model.scale + + weight = dataset_model.get_weight() + if weight is not None: + apply_weight(reduced_matrix, weight[:, i]) + + reduced_clps, residual = self._residual_function(reduced_matrix, data[:, i]) - reduced_clps, residual = self._residual_function( - reduced_matrix.values, data.isel({global_dimension: i}).values - ) - reduced_clps = xr.DataArray( - reduced_clps, - dims=["clp_label"], - coords={"clp_label": reduced_matrix.coords["clp_label"]}, - ) self._reduced_clps[label].append(reduced_clps) + + clp_labels = self._get_clp_labels(label, 
i) self._clps[label].append( - retrieve_clps(self.model, self.parameters, clp_labels, reduced_clps, index) - ) - residual = xr.DataArray( - residual, - dims=[model_dimension], - coords={model_dimension: reduced_matrix.coords[model_dimension]}, + retrieve_clps( + self.model, + self.parameters, + clp_labels, + reduced_clp_labels, + reduced_clps, + index, + ) ) self._weighted_residuals[label].append(residual) - if problem.weight is not None: - self._residuals[label].append( - residual / problem.weight.isel({global_dimension: i}) - ) + if weight is not None: + self._residuals[label].append(residual / weight[:, i]) else: self._residuals[label].append(residual) - self._reduced_clps[label] = xr.concat(self._reduced_clps[label], dim=global_dimension) - self._reduced_clps[label].coords[global_dimension] = data.coords[global_dimension] - self._clps[label] = xr.concat(self._clps[label], dim=global_dimension) - self._clps[label].coords[global_dimension] = data.coords[global_dimension] + clp_labels = self._get_clp_labels(label) additional_penalty = calculate_clp_penalties( - self.model, self.parameters, self._clps[label], global_dimension + self.model, self.parameters, clp_labels, self._clps[label], global_axis ) if additional_penalty.size != 0: self._additional_penalty.append(additional_penalty) + def _get_clp_labels(self, label: str, index: int = 0): + return ( + self.matrices[label][index].clp_labels + if self.dataset_models[label].index_dependent() + else self.matrices[label].clp_labels + ) + def create_index_dependent_result_dataset(self, label: str, dataset: xr.Dataset) -> xr.Dataset: """Creates a result datasets for index dependent matrices.""" - self._add_index_dependent_matrix_to_dataset(label, dataset) + model_dimension = self.dataset_models[label].get_model_dimension() + global_dimension = self.dataset_models[label].get_global_dimension() + + dataset.coords["clp_label"] = self._get_clp_labels(label) + dataset["matrix"] = ( + ( + (global_dimension), + (model_dimension), + ("clp_label"), + ), + np.asarray([m.matrix for m in self.matrices[label]]), + ) self._add_residual_and_full_clp_to_dataset(label, dataset) @@ -188,56 +167,44 @@ def create_index_independent_result_dataset( ) -> xr.Dataset: """Creates a result datasets for index independent matrices.""" - self._add_index_independent_matrix_to_dataset(label, dataset) - - self._add_residual_and_full_clp_to_dataset(label, dataset) - - return dataset - - def _add_index_dependent_matrix_to_dataset(self, label: str, dataset: xr.Dataset): + matrix = self.matrices[label] + dataset.coords["clp_label"] = matrix.clp_labels model_dimension = self.dataset_models[label].get_model_dimension() - global_dimension = self.dataset_models[label].get_global_dimension() - - matrix = xr.concat(self.matrices[label], dim=global_dimension) - matrix.coords[global_dimension] = dataset.coords[global_dimension] - dataset.coords["clp_label"] = matrix.coords["clp_label"] dataset["matrix"] = ( ( - (global_dimension), (model_dimension), ("clp_label"), ), - matrix.data, + matrix.matrix, ) - def _add_index_independent_matrix_to_dataset(self, label: str, dataset: xr.Dataset): - dataset.coords["clp_label"] = self.matrices[label].coords["clp_label"] + self._add_residual_and_full_clp_to_dataset(label, dataset) + + return dataset + + def _add_residual_and_full_clp_to_dataset(self, label: str, dataset: xr.Dataset): model_dimension = self.dataset_models[label].get_model_dimension() - dataset["matrix"] = ( + global_dimension = self.dataset_models[label].get_global_dimension() + 
dataset["clp"] = ( ( - (model_dimension), + (global_dimension), ("clp_label"), ), - self.matrices[label].data, + np.asarray(self.clps[label]), ) - - def _add_residual_and_full_clp_to_dataset(self, label: str, dataset: xr.Dataset): - model_dimension = self.dataset_models[label].get_model_dimension() - global_dimension = self.dataset_models[label].get_global_dimension() - dataset["clp"] = self.clps[label] dataset["weighted_residual"] = ( ( (model_dimension), (global_dimension), ), - xr.concat(self.weighted_residuals[label], dim=global_dimension).T.data, + np.transpose(np.asarray(self.weighted_residuals[label])), ) dataset["residual"] = ( ( (model_dimension), (global_dimension), ), - xr.concat(self.residuals[label], dim=global_dimension).T.data, + np.transpose(np.asarray(self.residuals[label])), ) @property diff --git a/glotaran/analysis/simulation.py b/glotaran/analysis/simulation.py index 964557251..80221b909 100644 --- a/glotaran/analysis/simulation.py +++ b/glotaran/analysis/simulation.py @@ -110,8 +110,8 @@ def simulate_clp( for i in range(global_axis.size): index_matrix = matrices[i] if dataset_model.index_dependent() else matrices result.data[:, i] = np.dot( - index_matrix, - clp.isel({global_dimension: i}).sel({"clp_label": index_matrix.coords["clp_label"]}), + index_matrix.matrix, + clp.isel({global_dimension: i}).sel({"clp_label": index_matrix.clp_labels}), ) return result @@ -132,7 +132,14 @@ def simulate_global_model( raise ValueError("Index dependent models for global dimension are not supported.") global_matrix = calculate_matrix(dataset_model, {}, global_model=True) - global_matrix = global_matrix.T + global_clp_labels = global_matrix.clp_labels + global_matrix = xr.DataArray( + global_matrix.matrix.T, + coords=[ + ("clp_label", global_clp_labels), + (dataset_model.get_global_dimension(), dataset_model.get_global_axis()), + ], + ) return simulate_clp( dataset_model, diff --git a/glotaran/analysis/test/models.py b/glotaran/analysis/test/models.py index 2d5cc9d81..87e96f611 100644 --- a/glotaran/analysis/test/models.py +++ b/glotaran/analysis/test/models.py @@ -3,7 +3,6 @@ from typing import List import numpy as np -import xarray as xr from glotaran.model import Megacomplex from glotaran.model import Model @@ -27,7 +26,7 @@ def calculate_matrix(self, dataset_model, indices, **kwargs): r_compartments.append(compartments[i]) for j in range(axis.shape[0]): array[j, i] = (i + j) * axis[j] - return xr.DataArray(array, coords=(("global", axis.data), ("clp_label", r_compartments))) + return r_compartments, array def index_dependent(self, dataset_model): return False @@ -49,7 +48,7 @@ def calculate_matrix(self, dataset_model, indices, **kwargs): r_compartments.append(compartments[i]) for j in range(axis.shape[0]): array[j, i] = (i + j) * axis[j] - return xr.DataArray(array, coords=(("model", axis.data), ("clp_label", r_compartments))) + return r_compartments, array def index_dependent(self, dataset_model): return self.is_index_dependent @@ -85,7 +84,7 @@ def calculate_matrix(self, dataset_model, indices, **kwargs): else: compartments = [f"s{i+1}" for i in range(len(kinpar))] array = np.exp(np.outer(axis, kinpar)) - return xr.DataArray(array, coords=(("model", axis.data), ("clp_label", compartments))) + return compartments, array def index_dependent(self, dataset_model): return self.is_index_dependent @@ -106,7 +105,7 @@ def calculate_matrix(self, dataset_model, indices, **kwargs): else: compartments = [f"s{i+1}" for i in range(len(kinpar))] array = np.asarray([[1 for _ in 
range(axis.size)] for _ in compartments]).T - return xr.DataArray(array, coords=(("global", axis.data), ("clp_label", compartments))) + return compartments, array def index_dependent(self, dataset_model): return False @@ -135,7 +134,7 @@ def calculate_matrix(self, dataset_model, indices, **kwargs): -np.log(2) * np.square(2 * (axis - location[i]) / delta[i]) ) compartments = [f"s{i+1}" for i in range(location.size)] - return xr.DataArray(array.T, coords=(("global", axis.data), ("clp_label", compartments))) + return compartments, array.T def index_dependent(self, dataset_model): return False diff --git a/glotaran/analysis/test/test_constraints.py b/glotaran/analysis/test/test_constraints.py index ee43af9da..0914e892c 100644 --- a/glotaran/analysis/test/test_constraints.py +++ b/glotaran/analysis/test/test_constraints.py @@ -17,7 +17,7 @@ def test_constraint(index_dependent, grouped): model.megacomplex["m1"].is_index_dependent = index_dependent model.constraints.append(ZeroConstraint.from_dict({"target": "s2"})) - print("grouped", grouped, "index_dependent", index_dependent) + print("grouped", grouped, "index_dependent", index_dependent) # T001 dataset = simulate( suite.sim_model, "dataset1", @@ -27,7 +27,6 @@ def test_constraint(index_dependent, grouped): scheme = Scheme(model=model, parameters=suite.initial_parameters, data={"dataset1": dataset}) problem = GroupedProblem(scheme) if grouped else UngroupedProblem(scheme) - reduced_clps = problem.reduced_clps["dataset1"] if index_dependent: reduced_matrix = ( problem.reduced_matrices[0] if grouped else problem.reduced_matrices["dataset1"][0] @@ -35,10 +34,12 @@ def test_constraint(index_dependent, grouped): else: reduced_matrix = problem.reduced_matrices["dataset1"] matrix = problem.matrices["dataset1"][0] if index_dependent else problem.matrices["dataset1"] - clps = problem.clps["dataset1"] - assert "s2" not in reduced_clps.coords["clp_label"] - assert "s2" not in reduced_matrix.coords["clp_label"] + result_data = problem.create_result_data() + print(result_data) # T001 + clps = result_data["dataset1"].clp + + assert "s2" not in reduced_matrix.clp_labels assert "s2" in clps.coords["clp_label"] assert clps.sel(clp_label="s2") == 0 - assert "s2" in matrix.coords["clp_label"] + assert "s2" in matrix.clp_labels diff --git a/glotaran/analysis/test/test_optimization.py b/glotaran/analysis/test/test_optimization.py index 9d29e9efa..05e8f2050 100644 --- a/glotaran/analysis/test/test_optimization.py +++ b/glotaran/analysis/test/test_optimization.py @@ -124,8 +124,6 @@ def test_optimization(suite, index_dependent, grouped, weight, method): assert np.allclose(dataset.data, resultdata.data) if weight: assert "weight" in resultdata - assert "weighted_data" in resultdata - assert np.allclose(resultdata.data, resultdata.weighted_data * 2) assert "weighted_residual" in resultdata assert "weighted_residual_left_singular_vectors" in resultdata assert "weighted_residual_right_singular_vectors" in resultdata diff --git a/glotaran/analysis/test/test_penalties.py b/glotaran/analysis/test/test_penalties.py index ddee6565c..3c3c3c1b2 100644 --- a/glotaran/analysis/test/test_penalties.py +++ b/glotaran/analysis/test/test_penalties.py @@ -14,7 +14,7 @@ @pytest.mark.parametrize("index_dependent", [True, False]) @pytest.mark.parametrize("grouped", [True, False]) -def test_constraint(index_dependent, grouped): +def test_penalties(index_dependent, grouped): model = deepcopy(suite.model) model.megacomplex["m1"].is_index_dependent = index_dependent 
model.clp_area_penalties.append( @@ -33,7 +33,7 @@ def test_constraint(index_dependent, grouped): global_axis = np.arange(50) - print("grouped", grouped, "index_dependent", index_dependent) + print("grouped", grouped, "index_dependent", index_dependent) # T001 dataset = simulate( suite.sim_model, "dataset1", diff --git a/glotaran/analysis/test/test_problem.py b/glotaran/analysis/test/test_problem.py index 99efd16e7..08344beeb 100644 --- a/glotaran/analysis/test/test_problem.py +++ b/glotaran/analysis/test/test_problem.py @@ -10,6 +10,7 @@ from glotaran.analysis.simulation import simulate from glotaran.analysis.test.models import MultichannelMulticomponentDecay as suite from glotaran.analysis.test.models import SimpleTestModel +from glotaran.analysis.util import CalculatedMatrix from glotaran.parameter import ParameterGroup from glotaran.project import Scheme @@ -36,15 +37,11 @@ def problem(request) -> Problem: def test_problem_bag(problem: Problem): - bag = problem.bag - if problem.grouped: + bag = problem.bag assert isinstance(bag, collections.deque) assert len(bag) == suite.global_axis.size assert problem.groups == {"dataset1": ["dataset1"]} - else: - assert isinstance(bag, dict) - assert "dataset1" in bag def test_problem_matrices(problem: Problem): @@ -52,18 +49,20 @@ def test_problem_matrices(problem: Problem): if problem.grouped: if problem.model.is_index_dependent: - assert all(isinstance(m, xr.DataArray) for m in problem.reduced_matrices) + assert all(isinstance(m, CalculatedMatrix) for m in problem.reduced_matrices) assert len(problem.reduced_matrices) == suite.global_axis.size else: assert "dataset1" in problem.reduced_matrices - assert isinstance(problem.reduced_matrices["dataset1"], xr.DataArray) + assert isinstance(problem.reduced_matrices["dataset1"], CalculatedMatrix) else: if problem.model.is_index_dependent: assert isinstance(problem.reduced_matrices, dict) assert isinstance(problem.reduced_matrices["dataset1"], list) - assert all(isinstance(m, xr.DataArray) for m in problem.reduced_matrices["dataset1"]) + assert all( + isinstance(m, CalculatedMatrix) for m in problem.reduced_matrices["dataset1"] + ) else: - assert isinstance(problem.reduced_matrices["dataset1"], xr.DataArray) + assert isinstance(problem.reduced_matrices["dataset1"], CalculatedMatrix) assert isinstance(problem.matrices, dict) assert "dataset1" in problem.reduced_matrices @@ -78,16 +77,8 @@ def test_problem_residuals(problem: Problem): else: assert isinstance(problem.residuals, dict) assert "dataset1" in problem.residuals - assert all(isinstance(r, xr.DataArray) for r in problem.residuals["dataset1"]) + assert all(isinstance(r, np.ndarray) for r in problem.residuals["dataset1"]) assert len(problem.residuals["dataset1"]) == suite.global_axis.size - assert isinstance(problem.reduced_clps, dict) - assert "dataset1" in problem.reduced_clps - assert all(isinstance(c, xr.DataArray) for c in problem.reduced_clps["dataset1"]) - assert len(problem.reduced_clps["dataset1"]) == suite.global_axis.size - assert isinstance(problem.clps, dict) - assert "dataset1" in problem.clps - assert all(isinstance(c, xr.DataArray) for c in problem.clps["dataset1"]) - assert len(problem.clps["dataset1"]) == suite.global_axis.size def test_problem_result_data(problem: Problem): @@ -155,7 +146,7 @@ def test_prepare_data(): ], } model = SimpleTestModel.from_dict(model_dict) - print(model.validate()) + print(model.validate()) # T001 assert model.valid() parameters = ParameterGroup.from_list([]) @@ -173,7 +164,7 @@ def test_prepare_data(): 
problem = Problem(scheme) data = problem.data["dataset1"] - print(data) + print(data) # T001 assert "data" in data assert "weight" in data @@ -189,7 +180,7 @@ def test_prepare_data(): } ) model = SimpleTestModel.from_dict(model_dict) - print(model.validate()) + print(model.validate()) # T001 assert model.valid() scheme = Scheme(model, parameters, {"dataset1": dataset}) diff --git a/glotaran/analysis/test/test_relations.py b/glotaran/analysis/test/test_relations.py index c0841efd3..ccc5f6d05 100644 --- a/glotaran/analysis/test/test_relations.py +++ b/glotaran/analysis/test/test_relations.py @@ -13,13 +13,13 @@ @pytest.mark.parametrize("index_dependent", [True, False]) @pytest.mark.parametrize("grouped", [True, False]) -def test_constraint(index_dependent, grouped): +def test_relations(index_dependent, grouped): model = deepcopy(suite.model) model.megacomplex["m1"].is_index_dependent = index_dependent model.relations.append(Relation.from_dict({"source": "s1", "target": "s2", "parameter": "3"})) parameters = ParameterGroup.from_list([11e-4, 22e-5, 2]) - print("grouped", grouped, "index_dependent", index_dependent) + print("grouped", grouped, "index_dependent", index_dependent) # T001 dataset = simulate( suite.sim_model, "dataset1", @@ -29,7 +29,6 @@ def test_constraint(index_dependent, grouped): scheme = Scheme(model=model, parameters=parameters, data={"dataset1": dataset}) problem = GroupedProblem(scheme) if grouped else UngroupedProblem(scheme) - reduced_clps = problem.reduced_clps["dataset1"] if index_dependent: reduced_matrix = ( problem.reduced_matrices[0] if grouped else problem.reduced_matrices["dataset1"][0] @@ -37,10 +36,12 @@ def test_constraint(index_dependent, grouped): else: reduced_matrix = problem.reduced_matrices["dataset1"] matrix = problem.matrices["dataset1"][0] if index_dependent else problem.matrices["dataset1"] - clps = problem.clps["dataset1"] - assert "s2" not in reduced_clps.coords["clp_label"] - assert "s2" not in reduced_matrix.coords["clp_label"] + result_data = problem.create_result_data() + print(result_data) # T001 + clps = result_data["dataset1"].clp + + assert "s2" not in reduced_matrix.clp_labels assert "s2" in clps.coords["clp_label"] assert clps.sel(clp_label="s2") == clps.sel(clp_label="s1") * 2 - assert "s2" in matrix.coords["clp_label"] + assert "s2" in matrix.clp_labels diff --git a/glotaran/analysis/util.py b/glotaran/analysis/util.py index dc857e6dd..342d93c60 100644 --- a/glotaran/analysis/util.py +++ b/glotaran/analysis/util.py @@ -2,7 +2,9 @@ import itertools from typing import Any +from typing import NamedTuple +import numba as nb import numpy as np import xarray as xr @@ -11,6 +13,11 @@ from glotaran.parameter import ParameterGroup +class CalculatedMatrix(NamedTuple): + clp_labels: list[str] + matrix: np.ndarray + + def find_overlap(a, b, rtol=1e-05, atol=1e-08): ovr_a = [] ovr_b = [] @@ -43,7 +50,9 @@ def calculate_matrix( dataset_model: DatasetModel, indices: dict[str, int] | None, global_model: bool = False, -) -> xr.DataArray: +) -> CalculatedMatrix: + + clp_labels = None matrix = None megacomplex_iterator = dataset_model.iterate_megacomplexes @@ -53,66 +62,78 @@ def calculate_matrix( dataset_model.swap_dimensions() for scale, megacomplex in megacomplex_iterator(): - this_matrix = megacomplex.calculate_matrix(dataset_model, indices) + this_clp_labels, this_matrix = megacomplex.calculate_matrix(dataset_model, indices) if scale is not None: this_matrix *= scale if matrix is None: + clp_labels = this_clp_labels matrix = this_matrix else: - 
matrix, this_matrix = xr.align(matrix, this_matrix, join="outer", copy=False) - matrix = matrix.fillna(0) - matrix += this_matrix.fillna(0) + tmp_clp_labels = clp_labels + [c for c in this_clp_labels if c not in clp_labels] + tmp_matrix = np.zeros((matrix.shape[0], len(tmp_clp_labels)), dtype=np.float64) + for idx, label in enumerate(tmp_clp_labels): + if label in clp_labels: + tmp_matrix[:, idx] += matrix[:, clp_labels.index(label)] + if label in this_clp_labels: + tmp_matrix[:, idx] += this_matrix[:, this_clp_labels.index(label)] + clp_labels = tmp_clp_labels + matrix = tmp_matrix if global_model: dataset_model.swap_dimensions() - return matrix + return CalculatedMatrix(clp_labels, matrix) + + +@nb.jit(nopython=True, parallel=True) +def apply_weight(matrix, weight): + for i in nb.prange(matrix.shape[1]): + matrix[:, i] *= weight def reduce_matrix( - matrix: xr.DataArray, + matrix: CalculatedMatrix, model: Model, parameters: ParameterGroup, - model_dimension: str, index: Any | None, -) -> xr.DataArray: - matrix = apply_relations(matrix, model, parameters, model_dimension, index) +) -> CalculatedMatrix: + matrix = apply_relations(matrix, model, parameters, index) matrix = apply_constraints(matrix, model, index) return matrix def apply_constraints( - matrix: xr.DataArray, + matrix: CalculatedMatrix, model: Model, index: Any | None, -) -> xr.DataArray: +) -> CalculatedMatrix: if len(model.constraints) == 0: return matrix - clp_labels = matrix.coords["clp_label"].values - removed_clp = [ + clp_labels = matrix.clp_labels + removed_clp_labels = [ c.target for c in model.constraints if c.target in clp_labels and c.applies(index) ] - reduced_clp_label = [c for c in clp_labels if c not in removed_clp] - - return matrix.sel({"clp_label": reduced_clp_label}) + reduced_clp_labels = [c for c in clp_labels if c not in removed_clp_labels] + mask = [label in reduced_clp_labels for label in clp_labels] + reduced_matrix = matrix.matrix[:, mask] + return CalculatedMatrix(reduced_clp_labels, reduced_matrix) def apply_relations( - matrix: xr.DataArray, + matrix: CalculatedMatrix, model: Model, parameters: ParameterGroup, - model_dimension: str, index: Any | None, -) -> xr.DataArray: +) -> CalculatedMatrix: if len(model.relations) == 0: return matrix - clp_labels = list(matrix.coords["clp_label"].values) + clp_labels = matrix.clp_labels relation_matrix = np.diagflat([1.0 for _ in clp_labels]) idx_to_delete = [] @@ -128,30 +149,28 @@ def apply_relations( relation_matrix[target_idx, source_idx] = relation.parameter idx_to_delete.append(target_idx) - reduced_clp_label = [label for i, label in enumerate(clp_labels) if i not in idx_to_delete] + reduced_clp_labels = [label for i, label in enumerate(clp_labels) if i not in idx_to_delete] relation_matrix = np.delete(relation_matrix, idx_to_delete, axis=1) - return xr.DataArray( - matrix.values @ relation_matrix, - dims=matrix.dims, - coords={ - "clp_label": reduced_clp_label, - model_dimension: matrix.coords[model_dimension], - }, - ) + reduced_matrix = matrix.matrix @ relation_matrix + return CalculatedMatrix(reduced_clp_labels, reduced_matrix) def retrieve_clps( model: Model, parameters: ParameterGroup, clp_labels: xr.DataArray, + reduced_clp_labels: xr.DataArray, reduced_clps: xr.DataArray, index: Any | None, ) -> xr.DataArray: if len(model.relations) == 0 and len(model.constraints) == 0: return reduced_clps - clps = xr.DataArray(np.zeros((clp_labels.size), dtype=np.float64), coords=[clp_labels]) - clps.loc[{"clp_label": reduced_clps.coords["clp_label"]}] = 
reduced_clps.values + clps = np.zeros(len(clp_labels)) + + for i, label in enumerate(reduced_clp_labels): + idx = clp_labels.index(label) + clps[idx] = reduced_clps[i] for relation in model.relations: relation = relation.fill(model, parameters) @@ -160,55 +179,78 @@ def retrieve_clps( and relation.applies(index) and relation.source in clp_labels ): - clps.loc[{"clp_label": relation.target}] = relation.parameter * clps.sel( - clp_label=relation.source - ) - + source_idx = clp_labels.index(relation.source) + target_idx = clp_labels.index(relation.target) + clps[target_idx] = relation.parameter * clps[source_idx] return clps def calculate_clp_penalties( model: Model, parameters: ParameterGroup, - clps: xr.DataArray, - global_dimension: str, + clp_labels: list[list[str]] | list[str], + clps: list[np.ndarray], + global_axis: np.ndarray, ) -> np.ndarray: penalties = [] for penalty in model.clp_area_penalties: - if ( - penalty.source in clps.coords["clp_label"] - and penalty.target in clps.coords["clp_label"] - ): - penalty = penalty.fill(model, parameters) - - source_area = xr.concat( - [ - clps.sel( - { - "clp_label": penalty.source, - global_dimension: slice(interval[0], interval[1]), - } - ) - for interval in penalty.source_intervals - ], - dim=global_dimension, - ) - - target_area = xr.concat( - [ - clps.sel( - { - "clp_label": penalty.target, - global_dimension: slice(interval[0], interval[1]), - } - ) - for interval in penalty.target_intervals - ], - dim=global_dimension, - ) - - area_penalty = np.abs(np.sum(source_area) - penalty.parameter * np.sum(target_area)) - penalties.append(area_penalty * penalty.weight) + penalty = penalty.fill(model, parameters) + source_area = _get_area( + penalty.source, + clp_labels, + clps, + penalty.source_intervals, + global_axis, + ) + + target_area = _get_area( + penalty.target, + clp_labels, + clps, + penalty.target_intervals, + global_axis, + ) + + area_penalty = np.abs(np.sum(source_area) - penalty.parameter * np.sum(target_area)) + + penalties.append(area_penalty * penalty.weight) return np.asarray(penalties) + + +def _get_area( + clp_label: str, + clp_labels: list[list[str]], + clps: list[np.ndarray], + intervals: list[tuple[float, float]], + global_axis: np.ndarray, +) -> np.ndarray: + area = [] + + for interval in intervals: + if interval[0] > global_axis[-1]: + continue + + start_idx, end_idx = get_idx_from_interval(interval, global_axis) + for i in range(start_idx, end_idx + 1): + index_clp_labels = clp_labels[i] if isinstance(clp_labels[0], list) else clp_labels + if clp_label in index_clp_labels: + area.append(clps[i][index_clp_labels.index(clp_label)]) + + return np.asarray(area) # TODO: normalize for distance on global axis + + +def get_idx_from_interval(interval: tuple[float, float], axis: np.ndarray) -> tuple[int, int]: + """Retrieves start and end index of an interval on some axis + Parameters + ---------- + interval : A tuple of floats with begin and end of the interval + axis : Array like object which can be cast to np.array + Returns + ------- + start, end : tuple of int + """ + start = np.abs(axis - interval[0]).argmin() if not np.isinf(interval[0]) else 0 + end = np.abs(axis - interval[1]).argmin() if not np.isinf(interval[1]) else axis.size - 1 + return start, end diff --git a/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py b/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py index 942384ed9..b701b9c0c 100644 --- a/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py +++ 
b/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py @@ -16,13 +16,10 @@ def calculate_matrix( indices: dict[str, int], **kwargs, ): - model_dimension = dataset_model.get_model_dimension() - model_axis = dataset_model.get_coordinates()[model_dimension] + model_axis = dataset_model.get_model_axis() clp_label = [f"{dataset_model.label}_baseline"] matrix = np.ones((model_axis.size, 1), dtype=np.float64) - return xr.DataArray( - matrix, coords=((model_dimension, model_axis.data), ("clp_label", clp_label)) - ) + return clp_label, matrix def index_dependent(self, dataset: DatasetModel) -> bool: return False diff --git a/glotaran/builtin/megacomplexes/baseline/test/test_baseline_megacomplex.py b/glotaran/builtin/megacomplexes/baseline/test/test_baseline_megacomplex.py index ccff81bc1..f19f0d226 100644 --- a/glotaran/builtin/megacomplexes/baseline/test/test_baseline_megacomplex.py +++ b/glotaran/builtin/megacomplexes/baseline/test/test_baseline_megacomplex.py @@ -1,5 +1,4 @@ import numpy as np -import xarray as xr from glotaran.analysis.util import calculate_matrix from glotaran.builtin.megacomplexes.baseline import BaselineMegacomplex @@ -43,17 +42,17 @@ def test_baseline(): ] ) - time = xr.DataArray(np.asarray(np.arange(0, 50, 1.5))) - pixel = xr.DataArray([0]) + time = np.asarray(np.arange(0, 50, 1.5)) + pixel = np.asarray([0]) coords = {"time": time, "pixel": pixel} dataset_model = model.dataset["dataset1"].fill(model, parameter) dataset_model.overwrite_global_dimension("pixel") dataset_model.set_coordinates(coords) matrix = calculate_matrix(dataset_model, {}) - compartments = matrix.coords["clp_label"] + compartments = matrix.clp_labels assert len(compartments) == 2 - assert compartments[0] == "dataset1_baseline" + assert "dataset1_baseline" in compartments - assert matrix.shape == (time.size, 2) - assert np.all(matrix[:, 0] == 1) + assert matrix.matrix.shape == (time.size, 2) + assert np.all(matrix.matrix[:, 1] == 1) diff --git a/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py b/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py index eb22c3f64..8e9cb4715 100644 --- a/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py +++ b/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py @@ -44,9 +44,8 @@ def calculate_matrix( global_dimension = dataset_model.get_global_dimension() global_index = indices.get(global_dimension) - global_axis = dataset_model.get_coordinates().get(global_dimension).values - model_dimension = dataset_model.get_model_dimension() - model_axis = dataset_model.get_coordinates()[model_dimension].values + global_axis = dataset_model.get_global_axis() + model_axis = dataset_model.get_model_axis() irf = dataset_model.irf @@ -55,9 +54,7 @@ def calculate_matrix( width = self.width.value if self.width is not None else width[0] matrix = _calculate_coherent_artifact_matrix(center, width, model_axis, self.order) - return xr.DataArray( - matrix, coords=((model_dimension, model_axis), ("clp_label", self.compartments())) - ) + return self.compartments(), matrix def compartments(self): return [f"coherent_artifact_{i}" for i in range(1, self.order + 1)] diff --git a/glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py b/glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py index f09697b98..9a02ce4aa 100644 --- a/glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py +++ 
b/glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py @@ -59,22 +59,22 @@ def test_coherent_artifact(): ] ) - time = xr.DataArray(np.arange(0, 50, 1.5)) - spectral = xr.DataArray([0]) + time = np.arange(0, 50, 1.5) + spectral = np.asarray([0]) coords = {"time": time, "spectral": spectral} dataset_model = model.dataset["dataset1"].fill(model, parameters) dataset_model.overwrite_global_dimension("spectral") dataset_model.set_coordinates(coords) matrix = calculate_matrix(dataset_model, {}) - compartments = matrix.coords["clp_label"].values + compartments = matrix.clp_labels print(compartments) assert len(compartments) == 4 for i in range(1, 4): - assert compartments[i - 1] == f"coherent_artifact_{i}" + assert compartments[i] == f"coherent_artifact_{i}" - assert matrix.shape == (time.size, 4) + assert matrix.matrix.shape == (time.size, 4) clp = xr.DataArray( [[1, 1, 1, 1]], diff --git a/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py b/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py index 0ea4a0ad1..25bddf55d 100644 --- a/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py @@ -89,9 +89,8 @@ def calculate_matrix( global_dimension = dataset_model.get_global_dimension() global_index = indices.get(global_dimension) - global_axis = dataset_model.get_coordinates().get(global_dimension).values - model_dimension = dataset_model.get_model_dimension() - model_axis = dataset_model.get_coordinates()[model_dimension].values + global_axis = dataset_model.get_global_axis() + model_axis = dataset_model.get_model_axis() # init the matrix size = (model_axis.size, rates.size) @@ -111,7 +110,7 @@ def calculate_matrix( matrix = matrix @ k_matrix.a_matrix(initial_concentration) # done - return xr.DataArray(matrix, coords=((model_dimension, model_axis), ("clp_label", species))) + return species, matrix def finalize_data(self, dataset_model: DatasetModel, data: xr.Dataset): global_dimension = dataset_model.get_global_dimension() diff --git a/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py index c89514984..32e266128 100644 --- a/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py @@ -63,8 +63,8 @@ class OneComponentOneChannel: [101e-3, [1, {"vary": False, "non-negative": False}]] ) - time = xr.DataArray(np.arange(0, 50, 1.5)) - pixel = xr.DataArray([0]) + time = np.arange(0, 50, 1.5) + pixel = np.asarray([0]) axis = {"time": time, "pixel": pixel} clp = xr.DataArray([[1]], coords=[("pixel", pixel.data), ("clp_label", ["s1"])]) @@ -115,8 +115,8 @@ class OneComponentOneChannelGaussianIrf: ] ) - time = xr.DataArray(np.arange(0, 50, 1.5)) - pixel = xr.DataArray([0]) + time = np.arange(0, 50, 1.5) + pixel = np.asarray([0]) axis = {"time": time, "pixel": pixel} clp = xr.DataArray([[1]], coords=[("pixel", pixel.data), ("clp_label", ["s1"])]) @@ -179,8 +179,8 @@ class ThreeComponentParallel: "j": [["1", 1, {"vary": False, "non-negative": False}]], } ) - time = xr.DataArray(np.arange(-10, 100, 1.5)) - pixel = xr.DataArray(np.arange(600, 750, 10)) + time = np.arange(-10, 100, 1.5) + pixel = np.arange(600, 750, 10) axis = {"time": time, "pixel": pixel} @@ -254,8 +254,8 @@ class ThreeComponentSequential: } ) - time = xr.DataArray(np.arange(-10, 50, 1.0)) - pixel = xr.DataArray(np.arange(600, 750, 10)) + time = np.arange(-10, 50, 1.0) + 
pixel = np.arange(600, 750, 10) axis = {"time": time, "pixel": pixel} clp = _create_gaussian_clp( diff --git a/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py b/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py index e2ca08519..4256cbfa2 100644 --- a/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py @@ -2,7 +2,6 @@ import numpy as np import pytest -import xarray as xr from glotaran.analysis.optimize import optimize from glotaran.analysis.simulation import simulate @@ -92,16 +91,16 @@ class SimpleIrfDispersion: time_p1 = np.linspace(-1, 2, 50, endpoint=False) time_p2 = np.linspace(2, 5, 30, endpoint=False) time_p3 = np.geomspace(5, 10, num=20) - time = xr.DataArray(np.concatenate([time_p1, time_p2, time_p3])) - spectral = xr.DataArray(np.arange(300, 500, 100)) + time = np.array(np.concatenate([time_p1, time_p2, time_p3])) + spectral = np.arange(300, 500, 100) axis = {"time": time, "spectral": spectral} class MultiIrfDispersion: model = load_model(MODEL_MULTI_IRF_DISPERSION, format_name="yml_str") parameters = load_parameters(PARAMETERS_MULTI_IRF_DISPERSION, format_name="yml_str") - time = xr.DataArray(np.arange(-1, 5, 0.2)) - spectral = xr.DataArray(np.arange(300, 500, 100)) + time = np.arange(-1, 5, 0.2) + spectral = np.arange(300, 500, 100) axis = {"time": time, "spectral": spectral} diff --git a/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py b/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py index 508e66a49..5b1ff20ad 100644 --- a/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py +++ b/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py @@ -34,8 +34,7 @@ def calculate_matrix( raise ModelError(f"More then one shape defined for compartment '{compartment}'") compartments.append(compartment) - model_dimension = dataset_model.get_model_dimension() - model_axis = dataset_model.get_coordinates()[model_dimension].data + model_axis = dataset_model.get_model_axis() if self.energy_spectrum: model_axis = 1e7 / model_axis @@ -45,9 +44,8 @@ def calculate_matrix( for i, shape in enumerate(self.shape.values()): matrix[:, i] += shape.calculate(model_axis) - return xr.DataArray( - matrix, coords=((model_dimension, model_axis), ("clp_label", compartments)) - ) + + return compartments, matrix def index_dependent(self, dataset: DatasetModel) -> bool: return False diff --git a/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py index 2f2bbd5cc..9e163292d 100644 --- a/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py +++ b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py @@ -75,15 +75,16 @@ class OneCompartmentModel: spectral_parameters = ParameterGroup.from_list([7, 20000, 800]) - time = xr.DataArray(np.arange(-10, 50, 1.5)) - spectral = xr.DataArray(np.arange(400, 600, 5)) + time = np.arange(-10, 50, 1.5) + spectral = np.arange(400, 600, 5) axis = {"time": time, "spectral": spectral} decay_dataset_model = decay_model.dataset["dataset1"].fill(decay_model, decay_parameters) decay_dataset_model.overwrite_global_dimension("spectral") decay_dataset_model.set_coordinates(axis) - clp = calculate_matrix(decay_dataset_model, {}) - decay_compartments = clp.coords["clp_label"].values + matrix = calculate_matrix(decay_dataset_model, {}) + decay_compartments = matrix.clp_labels + clp = xr.DataArray(matrix.matrix, coords=[("time", 
time), ("clp_label", decay_compartments)]) class ThreeCompartmentModel: @@ -172,15 +173,16 @@ class ThreeCompartmentModel: ] ) - time = xr.DataArray(np.arange(-10, 50, 1.5)) - spectral = xr.DataArray(np.arange(400, 600, 5)) + time = np.arange(-10, 50, 1.5) + spectral = np.arange(400, 600, 5) axis = {"time": time, "spectral": spectral} decay_dataset_model = decay_model.dataset["dataset1"].fill(decay_model, decay_parameters) decay_dataset_model.overwrite_global_dimension("spectral") decay_dataset_model.set_coordinates(axis) - clp = calculate_matrix(decay_dataset_model, {}) - decay_compartments = clp.coords["clp_label"].values + matrix = calculate_matrix(decay_dataset_model, {}) + decay_compartments = matrix.clp_labels + clp = xr.DataArray(matrix.matrix, coords=[("time", time), ("clp_label", decay_compartments)]) @pytest.mark.parametrize( diff --git a/glotaran/builtin/models/spectral/test/__init__.py b/glotaran/builtin/models/spectral/test/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/glotaran/examples/sequential.py b/glotaran/examples/sequential.py index c50d5cf9c..328102d5b 100644 --- a/glotaran/examples/sequential.py +++ b/glotaran/examples/sequential.py @@ -1,5 +1,4 @@ import numpy as np -import xarray as xr from glotaran.analysis.simulation import simulate from glotaran.builtin.megacomplexes.decay import DecayMegacomplex @@ -104,8 +103,8 @@ } ) -_time = xr.DataArray(np.arange(-1, 20, 0.01)) -_spectral = xr.DataArray(np.arange(600, 700, 1.4)) +_time = np.arange(-1, 20, 0.01) +_spectral = np.arange(600, 700, 1.4) dataset = simulate( sim_model, diff --git a/glotaran/model/clp_penalties.py b/glotaran/model/clp_penalties.py index 2906bfe25..c1c817fd4 100644 --- a/glotaran/model/clp_penalties.py +++ b/glotaran/model/clp_penalties.py @@ -16,9 +16,7 @@ from typing import Any from typing import Sequence - from glotaran.builtin.models.kinetic_spectrum.kinetic_spectrum_model import ( - KineticSpectrumModel, - ) + from glotaran.model.model import Model from glotaran.parameter import ParameterGroup @@ -61,12 +59,12 @@ def applies(interval): return any([applies(i) for i in self.interval]) -def has_spectral_penalties(model: KineticSpectrumModel) -> bool: +def has_spectral_penalties(model: Model) -> bool: return len(model.equal_area_penalties) != 0 def apply_spectral_penalties( - model: KineticSpectrumModel, + model: Model, parameters: ParameterGroup, clp_labels: dict[str, list[str] | list[list[str]]], clps: dict[str, list[np.ndarray]], diff --git a/glotaran/model/dataset_model.py b/glotaran/model/dataset_model.py index 2e407cc7e..749db9a0e 100644 --- a/glotaran/model/dataset_model.py +++ b/glotaran/model/dataset_model.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING from typing import Generator +import numpy as np import xarray as xr from glotaran.model.item import model_item @@ -110,13 +111,21 @@ def swap_dimensions(self): def set_data(self, data: xr.Dataset) -> DatasetModel: """Sets the dataset model's data.""" - self._data = data + self._coords = {name: dim.values for name, dim in data.coords.items()} + self._data = data.data.values + self._weight = data.weight.values if "weight" in data else None + if self._weight is not None: + self._data = self._data * self._weight return self - def get_data(self) -> xr.Dataset: + def get_data(self) -> np.ndarray: """Gets the dataset model's data.""" return self._data + def get_weight(self) -> np.ndarray: + """Gets the dataset model's weight.""" + return self._weight + def index_dependent(self) -> bool: """Indicates if the 
dataset model is index dependent.""" if hasattr(self, "_index_dependent"): @@ -131,15 +140,21 @@ def global_model(self) -> bool: """Indicates if the dataset model can model the global dimension.""" return len(self.global_megacomplex) != 0 - def set_coordinates(self, coords: xr.Dataset): + def set_coordinates(self, coords: dict[str, np.ndarray]): """Sets the dataset model's coordinates.""" self._coords = coords - def get_coordinates(self) -> xr.Dataset: + def get_coordinates(self) -> np.ndarray: """Gets the dataset model's coordinates.""" - if hasattr(self, "_coords"): - return self._coords - return self._data.coords + return self._coords + + def get_model_axis(self) -> np.ndarray: + """Gets the dataset model's model axis.""" + return self._coords[self.get_model_dimension()] + + def get_global_axis(self) -> np.ndarray: + """Gets the dataset model's global axis.""" + return self._coords[self.get_global_dimension()] @model_item_validator(False) def ensure_unique_megacomplexes(self, model: Model) -> list[str]: diff --git a/glotaran/model/test/test_model.py b/glotaran/model/test/test_model.py index cc862b6ea..4676b1abe 100644 --- a/glotaran/model/test/test_model.py +++ b/glotaran/model/test/test_model.py @@ -309,7 +309,7 @@ def test_items(test_model: Model): def test_fill(test_model: Model, parameter: ParameterGroup): - data = xr.DataArray([[1]], dims=("global", "model")).to_dataset(name="data") + data = xr.DataArray([[1]], coords=(("global", [0]), ("model", [0]))).to_dataset(name="data") dataset = test_model.dataset.get("dataset1").fill(test_model, parameter) dataset.set_data(data) assert [cmplx.label for cmplx in dataset.megacomplex] == ["m1"] diff --git a/glotaran/test/test_spectral_decay.py b/glotaran/test/test_spectral_decay.py index 9c0fb5995..d18b5534b 100644 --- a/glotaran/test/test_spectral_decay.py +++ b/glotaran/test/test_spectral_decay.py @@ -1,6 +1,5 @@ import numpy as np import pytest -import xarray as xr from glotaran.analysis.optimize import optimize from glotaran.analysis.simulation import simulate @@ -8,94 +7,6 @@ from glotaran.io import load_parameters from glotaran.project import Scheme -MODEL_1C_BASE = """\ -dataset: - dataset1: &dataset1 - megacomplex: [mc1] - global_megacomplex: [mc2] - initial_concentration: j1 -initial_concentration: - j1: - compartments: [s1] - parameters: ["1"] -megacomplex: - mc1: - type: decay - k_matrix: [k1] - mc2: - type: spectral - shape: - s1: sh1 -k_matrix: - k1: - matrix: - (s1, s1): "2" -shape: - sh1: - type: one -""" -MODEL_1C_NO_IRF = MODEL_1C_BASE - -PARAMETERS_1C_NO_IRF_BASE = """\ -- [1, {"vary": False, "non-negative": False}] -""" - -PARAMETERS_1C_INITIAL = f"""\ -{PARAMETERS_1C_NO_IRF_BASE} -- 101e-4 -""" - -PARAMETERS_1C_WANTED = f"""\ -{PARAMETERS_1C_NO_IRF_BASE} -- 101e-3 -""" - -MODEL_1C_GAUSSIAN_IRF = f"""\ -{MODEL_1C_BASE} -irf: - irf1: - type: spectral-gaussian - center: "3" - width: "4" -dataset: - dataset1: - <<: *dataset1 - irf: irf1 -""" - -PARAMETERS_1C_GAUSSIAN_IRF_INITIAL = f"""\ -{PARAMETERS_1C_NO_IRF_BASE} -- 100e-4 -- 0.1 -- 1 -""" - -PARAMETERS_1C_GAUSSIAN_WANTED = f"""\ -{PARAMETERS_1C_NO_IRF_BASE} -- 101e-3 -- 0.3 -- 2 -""" - - -class OneComponentOneChannel: - model = load_model(MODEL_1C_NO_IRF, format_name="yml_str") - initial_parameters = load_parameters(PARAMETERS_1C_INITIAL, format_name="yml_str") - wanted_parameters = load_parameters(PARAMETERS_1C_WANTED, format_name="yml_str") - time = xr.DataArray(np.arange(0, 50, 1.5)) - spectral = xr.DataArray([0]) - axis = {"time": time, "spectral": spectral} - - -class 
OneComponentOneChannelGaussianIrf:
-    model = load_model(MODEL_1C_GAUSSIAN_IRF, format_name="yml_str")
-    initial_parameters = load_parameters(PARAMETERS_1C_GAUSSIAN_IRF_INITIAL, format_name="yml_str")
-    wanted_parameters = load_parameters(PARAMETERS_1C_GAUSSIAN_WANTED, format_name="yml_str")
-    time = xr.DataArray(np.arange(-10, 50, 1.5))
-    spectral = xr.DataArray([0])
-    axis = {"time": time, "spectral": spectral}
-
-
 MODEL_3C_BASE = """\
 dataset:
   dataset1: &dataset1
@@ -227,8 +138,8 @@ class ThreeComponentParallel:
     model = load_model(MODEL_3C_PARALLEL, format_name="yml_str")
     initial_parameters = load_parameters(PARAMETERS_3C_INITIAL_PARALLEL, format_name="yml_str")
     wanted_parameters = load_parameters(PARAMETERS_3C_PARALLEL_WANTED, format_name="yml_str")
-    time = xr.DataArray(np.arange(-10, 100, 1.5))
-    spectral = xr.DataArray(np.arange(600, 750, 10))
+    time = np.arange(-10, 100, 1.5)
+    spectral = np.arange(600, 750, 10)
     axis = {"time": time, "spectral": spectral}


@@ -236,16 +147,14 @@ class ThreeComponentSequential:
     model = load_model(MODEL_3C_SEQUENTIAL, format_name="yml_str")
     initial_parameters = load_parameters(PARAMETERS_3C_INITIAL_SEQUENTIAL, format_name="yml_str")
     wanted_parameters = load_parameters(PARAMETERS_3C_SIM_SEQUENTIAL, format_name="yml_str")
-    time = xr.DataArray(np.arange(-10, 50, 1.0))
-    spectral = xr.DataArray(np.arange(600, 750, 5.0))
+    time = np.arange(-10, 50, 1.0)
+    spectral = np.arange(600, 750, 5.0)
     axis = {"time": time, "spectral": spectral}


 @pytest.mark.parametrize(
     "suite",
     [
-        OneComponentOneChannel,
-        OneComponentOneChannelGaussianIrf,
         ThreeComponentParallel,
         ThreeComponentSequential,
     ],
diff --git a/glotaran/test/test_spectral_penalties.py b/glotaran/test/test_spectral_penalties.py
index 7e4fcb1c0..000c8f418 100644
--- a/glotaran/test/test_spectral_penalties.py
+++ b/glotaran/test/test_spectral_penalties.py
@@ -6,7 +6,6 @@
 from copy import deepcopy

 import numpy as np
-import xarray as xr

 from glotaran.analysis.optimize import optimize
 from glotaran.analysis.simulation import simulate
@@ -72,11 +71,11 @@ def test_equal_area_penalties(debug=False):
     optim_spec = OptimizationSpec(nnls=True, max_nfev=999)
     noise_spec = NoiseSpec(active=True, seed=1, std_dev=1e-8)

-    wavelengths = xr.DataArray(np.arange(650, 670, 2))
-    time_p1 = xr.DataArray(np.linspace(-1, 2, 50, endpoint=False))
-    time_p2 = xr.DataArray(np.linspace(2, 10, 30, endpoint=False))
-    time_p3 = xr.DataArray(np.geomspace(10, 50, num=20))
-    times = xr.DataArray(np.concatenate([time_p1, time_p2, time_p3]))
+    wavelengths = np.arange(650, 670, 2)
+    time_p1 = np.linspace(-1, 2, 50, endpoint=False)
+    time_p2 = np.linspace(2, 10, 30, endpoint=False)
+    time_p3 = np.geomspace(10, 50, num=20)
+    times = np.concatenate([time_p1, time_p2, time_p3])

     irf_loc = float(times[20])
     irf_width = float((times[1] - times[0]) * 10)

From 0b86a1511d7a140b6f31c1a87bfa683061cab928 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?J=C3=B6rn=20Wei=C3=9Fenborn?=
Date: Tue, 20 Jul 2021 21:22:39 +0200
Subject: [PATCH 06/29] =?UTF-8?q?=E2=9C=A8=20Feature:=20Full=20Models=20(#?=
 =?UTF-8?q?747)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Added Full Model functionality (also known as spectrotemporal models)
* Added finalizing functions for full models and changed the spectral decay test to a full model test
* Added flake8-print to pre-commit
* Fixed parameter history and set the correct optimized parameters prior to result data creation
* Don't apply clp relations and constraints on full models
* Added auto inference 
of grouping * Update glotaran/test/test_spectral_decay_full_model.py * Correction of cost calculation Co-authored-by: Joris Snellenburg Co-authored-by: Sebastian Weigand --- .pre-commit-config.yaml | 2 +- benchmark/pytest/analysis/test_problem.py | 15 +- glotaran/analysis/nnls.py | 4 +- glotaran/analysis/optimize.py | 18 +- glotaran/analysis/problem.py | 2 +- glotaran/analysis/problem_ungrouped.py | 134 ++++++++++- glotaran/analysis/simulation.py | 4 +- glotaran/analysis/test/models.py | 68 +++++- glotaran/analysis/test/test_constraints.py | 4 +- glotaran/analysis/test/test_optimization.py | 73 ++++-- glotaran/analysis/test/test_problem.py | 25 +- glotaran/analysis/util.py | 8 +- glotaran/analysis/variable_projection.py | 2 +- .../baseline/baseline_megacomplex.py | 11 +- .../coherent_artifact_megacomplex.py | 31 ++- .../megacomplexes/decay/decay_megacomplex.py | 43 +++- glotaran/builtin/megacomplexes/decay/util.py | 77 +++--- .../spectral/spectral_megacomplex.py | 51 ++-- glotaran/model/dataset_model.py | 26 +- glotaran/model/interval_property.py | 11 +- glotaran/model/megacomplex.py | 8 +- glotaran/model/model.py | 17 +- glotaran/model/test/test_model.py | 48 +++- glotaran/project/scheme.py | 12 +- glotaran/test/test_spectral_decay.py | 112 ++++++++- .../test/test_spectral_decay_full_model.py | 224 ++++++++++++++++++ glotaran/test/test_spectral_penalties.py | 18 +- 27 files changed, 872 insertions(+), 176 deletions(-) create mode 100644 glotaran/test/test_spectral_decay_full_model.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1491a1924..d22f3a349 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -113,7 +113,7 @@ repos: - id: yesqa types: [file] types_or: [python, pyi] - additional_dependencies: [flake8-docstrings] + additional_dependencies: [flake8-docstrings, flake8-print] - repo: https://github.com/PyCQA/flake8 rev: 3.9.2 diff --git a/benchmark/pytest/analysis/test_problem.py b/benchmark/pytest/analysis/test_problem.py index 265060feb..9c148818f 100644 --- a/benchmark/pytest/analysis/test_problem.py +++ b/benchmark/pytest/analysis/test_problem.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np import pytest import xarray as xr @@ -10,6 +14,9 @@ from glotaran.parameter import ParameterGroup from glotaran.project import Scheme +if TYPE_CHECKING: + from glotaran.model import DatasetModel + TEST_AXIS_MODEL_SIZE = 100 TEST_AXIS_MODEL = xr.DataArray(np.arange(0, TEST_AXIS_MODEL_SIZE)) TEST_AXIS_GLOBAL_SIZE = 100 @@ -36,7 +43,13 @@ def calculate_matrix(self, dataset_model, indices, **kwargs): def index_dependent(self, dataset_model): return self.is_index_dependent - def finalize_data(self, dataset_model, data): + def finalize_data( + self, + dataset_model: DatasetModel, + dataset: xr.Dataset, + is_full_model: bool = False, + as_global: bool = False, + ): pass diff --git a/glotaran/analysis/nnls.py b/glotaran/analysis/nnls.py index 3a6ee5b8b..23de3c54c 100644 --- a/glotaran/analysis/nnls.py +++ b/glotaran/analysis/nnls.py @@ -7,9 +7,7 @@ from scipy.optimize import nnls -def residual_nnls( - matrix: np.ndarray, data: np.ndarray -) -> typing.Tuple[typing.List[str], np.ndarray]: +def residual_nnls(matrix: np.ndarray, data: np.ndarray) -> typing.Tuple[np.ndarray, np.ndarray]: """Calculate the conditionally linear parameters and residual with the nnls method. nnls stands for 'non-negative least-squares'. 
diff --git a/glotaran/analysis/optimize.py b/glotaran/analysis/optimize.py
index a53dbd5be..a07169f17 100644
--- a/glotaran/analysis/optimize.py
+++ b/glotaran/analysis/optimize.py
@@ -19,12 +19,14 @@
 }


-def optimize(scheme: Scheme, verbose: bool = True) -> Result:
-    problem = GroupedProblem(scheme) if scheme.group else UngroupedProblem(scheme)
-    return optimize_problem(problem, verbose=verbose)
+def optimize(scheme: Scheme, verbose: bool = True, raise_exception: bool = False) -> Result:
+    problem = GroupedProblem(scheme) if scheme.is_grouped() else UngroupedProblem(scheme)
+    return optimize_problem(problem, verbose=verbose, raise_exception=raise_exception)


-def optimize_problem(problem: Problem, verbose: bool = True) -> Result:
+def optimize_problem(
+    problem: Problem, verbose: bool = True, raise_exception: bool = False
+) -> Result:

     if problem.scheme.optimization_method not in SUPPORTED_METHODS:
         raise ValueError(
@@ -61,12 +63,12 @@ def optimize_problem(problem: Problem, verbose: bool = True) -> Result:
         )
         termination_reason = ls_result.message
     except Exception as e:
+        if raise_exception:
+            raise e
         warn(f"Optimization failed:\n\n{e}")
         termination_reason = str(e)
         ls_result = None

-    problem.save_parameters_for_history()
-
     return _create_result(problem, ls_result, free_parameter_labels, termination_reason)


@@ -101,7 +103,9 @@ def _create_result(
     root_mean_square_error = np.sqrt(reduced_chi_square) if success else None
     jacobian = ls_result.jac if success else None

-    problem.save_parameters_for_history()
+    if success:
+        problem.parameters.set_from_label_and_value_arrays(free_parameter_labels, ls_result.x)
+        problem.reset()
     history_index = None if success else -2
     data = problem.create_result_data(history_index=history_index)
     # the optimized parameters are those of the last run if the optimization has crashed
diff --git a/glotaran/analysis/problem.py b/glotaran/analysis/problem.py
index 400b179e3..bc7463499 100644
--- a/glotaran/analysis/problem.py
+++ b/glotaran/analysis/problem.py
@@ -209,7 +209,7 @@ def full_penalty(self) -> np.ndarray:

     @property
     def cost(self) -> float:
-        return np.sum(self._full_penalty)
+        return 0.5 * np.dot(self.full_penalty, self.full_penalty)

     def save_parameters_for_history(self):
         self._parameter_history.append(self._parameters)
diff --git a/glotaran/analysis/problem_ungrouped.py b/glotaran/analysis/problem_ungrouped.py
index 1acdf5cd6..10b03eff9 100644
--- a/glotaran/analysis/problem_ungrouped.py
+++ b/glotaran/analysis/problem_ungrouped.py
@@ -12,11 +12,37 @@
 from glotaran.analysis.util import reduce_matrix
 from glotaran.analysis.util import retrieve_clps
 from glotaran.model import DatasetModel
+from glotaran.project import Scheme


 class UngroupedProblem(Problem):
     """Represents a problem where the data is not grouped."""

+    def __init__(self, scheme: Scheme):
+        """Initializes the Problem class from a scheme (:class:`glotaran.analysis.scheme.Scheme`)
+
+        Args:
+            scheme (Scheme): An instance of :class:`glotaran.analysis.scheme.Scheme`
+                which defines your model, parameters, and data
+        """
+        super().__init__(scheme=scheme)
+
+        self._global_matrices = {}
+        self._flattened_data = {}
+        self._flattened_weights = {}
+        for label, dataset_model in self.dataset_models.items():
+            if dataset_model.has_global_model():
+                self._flattened_data[label] = dataset_model.get_data().T.flatten()
+                weight = dataset_model.get_weight()
+                if weight is not None:
+                    weight = weight.T.flatten()
+                    self._flattened_data[label] *= weight
+                self._flattened_weights[label] = weight
+
+    @property
+    def 
global_matrices(self) -> dict[str, CalculatedMatrix]: + return self._global_matrices + def calculate_matrices( self, ) -> tuple[ @@ -28,6 +54,7 @@ def calculate_matrices( raise ParameterError self._matrices = {} + self._global_matrices = {} self._reduced_matrices = {} for label, dataset_model in self.dataset_models.items(): @@ -37,6 +64,9 @@ def calculate_matrices( else: self._calculate_index_independent_matrix(label, dataset_model) + if dataset_model.has_global_model(): + self._calculate_global_matrix(label, dataset_model) + return self._matrices, self._reduced_matrices def _calculate_index_dependent_matrix(self, label: str, dataset_model: DatasetModel): @@ -48,15 +78,20 @@ def _calculate_index_dependent_matrix(self, label: str, dataset_model: DatasetMo {dataset_model.get_global_dimension(): i}, ) self._matrices[label].append(matrix) - reduced_matrix = reduce_matrix(matrix, self.model, self.parameters, index) - self._reduced_matrices[label].append(reduced_matrix) + if not dataset_model.has_global_model(): + reduced_matrix = reduce_matrix(matrix, self.model, self.parameters, index) + self._reduced_matrices[label].append(reduced_matrix) def _calculate_index_independent_matrix(self, label: str, dataset_model: DatasetModel): - matrix = calculate_matrix(dataset_model, {}) self._matrices[label] = matrix - reduced_matrix = reduce_matrix(matrix, self.model, self.parameters, None) - self._reduced_matrices[label] = reduced_matrix + if not dataset_model.has_global_model(): + reduced_matrix = reduce_matrix(matrix, self.model, self.parameters, None) + self._reduced_matrices[label] = reduced_matrix + + def _calculate_global_matrix(self, label: str, dataset_model: DatasetModel): + matrix = calculate_matrix(dataset_model, {}, as_global_model=True) + self._global_matrices[label] = matrix def calculate_residual( self, @@ -75,7 +110,10 @@ def calculate_residual( self._additional_penalty = [] for label, dataset_model in self._dataset_models.items(): - self._calculate_residual(label, dataset_model) + if dataset_model.has_global_model(): + self._calculate_full_model_residual(label, dataset_model) + else: + self._calculate_residual(label, dataset_model) self._additional_penalty = ( np.concatenate(self._additional_penalty) if len(self._additional_penalty) != 0 else [] @@ -135,6 +173,30 @@ def _calculate_residual(self, label: str, dataset_model: DatasetModel): if additional_penalty.size != 0: self._additional_penalty.append(additional_penalty) + def _calculate_full_model_residual(self, label: str, dataset_model: DatasetModel): + + model_matrix = self.matrices[label] + global_matrix = self.global_matrices[label].matrix + + if dataset_model.index_dependent(): + matrix = np.concatenate( + [ + np.kron(global_matrix[i, :], model_matrix[i].matrix) + for i in range(global_matrix.shape[0]) + ] + ) + else: + matrix = np.kron(global_matrix, model_matrix.matrix) + weight = self._flattened_weights.get(label) + if weight is not None: + apply_weight(matrix, weight) + data = self._flattened_data[label] + self._clps[label], self._weighted_residuals[label] = self._residual_function(matrix, data) + + self._residuals[label] = self._weighted_residuals[label] + if weight is not None: + self._residuals[label] /= weight + def _get_clp_labels(self, label: str, index: int = 0): return ( self.matrices[label][index].clp_labels @@ -158,7 +220,11 @@ def create_index_dependent_result_dataset(self, label: str, dataset: xr.Dataset) np.asarray([m.matrix for m in self.matrices[label]]), ) - self._add_residual_and_full_clp_to_dataset(label, 
dataset) + if self.dataset_models[label].has_global_model(): + self._add_global_matrix_to_dataset(label, dataset) + self._add_full_model_residual_and_clp_to_dataset(label, dataset) + else: + self._add_residual_and_clp_to_dataset(label, dataset) return dataset @@ -178,11 +244,27 @@ def create_index_independent_result_dataset( matrix.matrix, ) - self._add_residual_and_full_clp_to_dataset(label, dataset) + if self.dataset_models[label].has_global_model(): + self._add_global_matrix_to_dataset(label, dataset) + self._add_full_model_residual_and_clp_to_dataset(label, dataset) + else: + self._add_residual_and_clp_to_dataset(label, dataset) return dataset - def _add_residual_and_full_clp_to_dataset(self, label: str, dataset: xr.Dataset): + def _add_global_matrix_to_dataset(self, label: str, dataset: xr.Dataset) -> xr.Dataset: + matrix = self.global_matrices[label] + dataset.coords["global_clp_label"] = matrix.clp_labels + global_dimension = self.dataset_models[label].get_global_dimension() + dataset["global_matrix"] = ( + ( + (global_dimension), + ("global_clp_label"), + ), + matrix.matrix, + ) + + def _add_residual_and_clp_to_dataset(self, label: str, dataset: xr.Dataset): model_dimension = self.dataset_models[label].get_model_dimension() global_dimension = self.dataset_models[label].get_global_dimension() dataset["clp"] = ( @@ -207,12 +289,44 @@ def _add_residual_and_full_clp_to_dataset(self, label: str, dataset: xr.Dataset) np.transpose(np.asarray(self.residuals[label])), ) + def _add_full_model_residual_and_clp_to_dataset(self, label: str, dataset: xr.Dataset): + model_dimension = self.dataset_models[label].get_model_dimension() + global_dimension = self.dataset_models[label].get_global_dimension() + dataset["clp"] = ( + ( + ("global_clp_label"), + ("clp_label"), + ), + self.clps[label].reshape( + (dataset.coords["global_clp_label"].size, dataset.coords["clp_label"].size) + ), + ) + dataset["weighted_residual"] = ( + ( + (model_dimension), + (global_dimension), + ), + self.weighted_residuals[label].T.reshape(dataset.data.shape), + ) + dataset["residual"] = ( + ( + (model_dimension), + (global_dimension), + ), + self.residuals[label].T.reshape(dataset.data.shape), + ) + @property def full_penalty(self) -> np.ndarray: if self._full_penalty is None: residuals = self.weighted_residuals additional_penalty = self.additional_penalty - residuals = [np.concatenate(residuals[label]) for label in residuals.keys()] + residuals = [ + np.concatenate(residuals[label]) + if isinstance(residuals[label], list) + else residuals[label] + for label in residuals.keys() + ] self._full_penalty = ( np.concatenate((np.concatenate(residuals), additional_penalty)) diff --git a/glotaran/analysis/simulation.py b/glotaran/analysis/simulation.py index 80221b909..2bbdc0b00 100644 --- a/glotaran/analysis/simulation.py +++ b/glotaran/analysis/simulation.py @@ -49,7 +49,7 @@ def simulate( dataset_model = model.dataset[dataset].fill(model, parameters) dataset_model.set_coordinates(coordinates) - if dataset_model.global_model(): + if dataset_model.has_global_model(): result = simulate_global_model( dataset_model, parameters, @@ -131,7 +131,7 @@ def simulate_global_model( if any(m.index_dependent(dataset_model) for m in dataset_model.global_megacomplex): raise ValueError("Index dependent models for global dimension are not supported.") - global_matrix = calculate_matrix(dataset_model, {}, global_model=True) + global_matrix = calculate_matrix(dataset_model, {}, as_global_model=True) global_clp_labels = global_matrix.clp_labels 
global_matrix = xr.DataArray( global_matrix.matrix.T, diff --git a/glotaran/analysis/test/models.py b/glotaran/analysis/test/models.py index 87e96f611..004e3746e 100644 --- a/glotaran/analysis/test/models.py +++ b/glotaran/analysis/test/models.py @@ -89,7 +89,13 @@ def calculate_matrix(self, dataset_model, indices, **kwargs): def index_dependent(self, dataset_model): return self.is_index_dependent - def finalize_data(self, dataset_model, data): + def finalize_data( + self, + dataset_model, + dataset, + is_full_model: bool = False, + as_global: bool = False, + ): pass @@ -139,6 +145,15 @@ def calculate_matrix(self, dataset_model, indices, **kwargs): def index_dependent(self, dataset_model): return False + def finalize_data( + self, + dataset_model, + dataset, + is_full_model: bool = False, + as_global: bool = False, + ): + pass + class DecayModel(Model): @classmethod @@ -306,7 +321,6 @@ class MultichannelMulticomponentDecay: sim_model = DecayModel.from_dict( { - # "compartment": ["s1", "s2", "s3", "s4"], "megacomplex": { "m1": {"is_index_dependent": False}, "m2": { @@ -327,7 +341,6 @@ class MultichannelMulticomponentDecay: ) model = DecayModel.from_dict( { - # "compartment": ["s1", "s2", "s3", "s4"], "megacomplex": {"m1": {"is_index_dependent": False}}, "dataset": { "dataset1": { @@ -337,3 +350,52 @@ class MultichannelMulticomponentDecay: }, } ) + + +class FullModel: + model = DecayModel.from_dict( + { + "megacomplex": { + "m1": {"is_index_dependent": False}, + "m2": { + "type": "global_complex_shaped", + "location": ["loc.1", "loc.2", "loc.3", "loc.4"], + "delta": ["del.1", "del.2", "del.3", "del.4"], + "amplitude": ["amp.1", "amp.2", "amp.3", "amp.4"], + }, + }, + "dataset": { + "dataset1": { + "megacomplex": ["m1"], + "global_megacomplex": ["m2"], + "kinetic": ["k.1", "k.2", "k.3", "k.4"], + } + }, + } + ) + parameters = ParameterGroup.from_dict( + { + "k": [0.006, 0.003, 0.0003, 0.03], + "loc": [ + ["1", 14705], + ["2", 13513], + ["3", 14492], + ["4", 14388], + ], + "amp": [ + ["1", 1], + ["2", 2], + ["3", 5], + ["4", 20], + ], + "del": [ + ["1", 400], + ["2", 100], + ["3", 300], + ["4", 200], + ], + } + ) + global_axis = np.arange(12820, 15120, 50) + model_axis = np.arange(0, 150, 1.5) + coordinates = {"global": global_axis, "model": model_axis} diff --git a/glotaran/analysis/test/test_constraints.py b/glotaran/analysis/test/test_constraints.py index 0914e892c..07b0c9611 100644 --- a/glotaran/analysis/test/test_constraints.py +++ b/glotaran/analysis/test/test_constraints.py @@ -17,7 +17,7 @@ def test_constraint(index_dependent, grouped): model.megacomplex["m1"].is_index_dependent = index_dependent model.constraints.append(ZeroConstraint.from_dict({"target": "s2"})) - print("grouped", grouped, "index_dependent", index_dependent) # T001 + print("grouped", grouped, "index_dependent", index_dependent) # noqa T001 dataset = simulate( suite.sim_model, "dataset1", @@ -36,7 +36,7 @@ def test_constraint(index_dependent, grouped): matrix = problem.matrices["dataset1"][0] if index_dependent else problem.matrices["dataset1"] result_data = problem.create_result_data() - print(result_data) # T001 + print(result_data) # noqa T001 clps = result_data["dataset1"].clp assert "s2" not in reduced_matrix.clp_labels diff --git a/glotaran/analysis/test/test_optimization.py b/glotaran/analysis/test/test_optimization.py index 05e8f2050..80d4268c9 100644 --- a/glotaran/analysis/test/test_optimization.py +++ b/glotaran/analysis/test/test_optimization.py @@ -4,6 +4,7 @@ from glotaran.analysis.optimize import 
optimize from glotaran.analysis.simulation import simulate +from glotaran.analysis.test.models import FullModel from glotaran.analysis.test.models import MultichannelMulticomponentDecay from glotaran.analysis.test.models import OneCompartmentDecay from glotaran.analysis.test.models import ThreeDatasetDecay @@ -31,26 +32,26 @@ def test_optimization(suite, index_dependent, grouped, weight, method): model.megacomplex["m1"].is_index_dependent = index_dependent - print("Grouped:", grouped) - print("Index dependent:", index_dependent) + print("Grouped:", grouped) # noqa T001 + print("Index dependent:", index_dependent) # noqa T001 sim_model = suite.sim_model sim_model.megacomplex["m1"].is_index_dependent = index_dependent - print(model.validate()) + print(model.validate()) # noqa T001 assert model.valid() - print(sim_model.validate()) + print(sim_model.validate()) # noqa T001 assert sim_model.valid() wanted_parameters = suite.wanted_parameters - print(wanted_parameters) - print(sim_model.validate(wanted_parameters)) + print(wanted_parameters) # noqa T001 + print(sim_model.validate(wanted_parameters)) # noqa T001 assert sim_model.valid(wanted_parameters) initial_parameters = suite.initial_parameters - print(initial_parameters) - print(model.validate(initial_parameters)) + print(initial_parameters) # noqa T001 + print(model.validate(initial_parameters)) # noqa T001 assert model.valid(initial_parameters) assert ( model.dataset["dataset1"].fill(model, initial_parameters).index_dependent() @@ -69,9 +70,9 @@ def test_optimization(suite, index_dependent, grouped, weight, method): wanted_parameters, {"global": global_axis, "model": model_axis}, ) - print(f"Dataset {i+1}") - print("=============") - print(dataset) + print(f"Dataset {i+1}") # noqa T001 + print("=============") # noqa T001 + print(dataset) # noqa T001 if hasattr(suite, "scale"): dataset["data"] /= suite.scale @@ -95,8 +96,8 @@ def test_optimization(suite, index_dependent, grouped, weight, method): optimization_method=method, ) - result = optimize(scheme) - print(result.optimized_parameters) + result = optimize(scheme, raise_exception=True) + print(result.optimized_parameters) # noqa T001 assert result.success optimized_scheme = result.get_scheme() assert result.optimized_parameters == optimized_scheme.parameters @@ -110,9 +111,9 @@ def test_optimization(suite, index_dependent, grouped, weight, method): for i, dataset in enumerate(data.values()): resultdata = result.data[f"dataset{i+1}"] - print(f"Result Data {i+1}") - print("=================") - print(resultdata) + print(f"Result Data {i+1}") # noqa T001 + print("=================") # noqa T001 + print(resultdata) # noqa T001 assert "residual" in resultdata assert "residual_left_singular_vectors" in resultdata assert "residual_right_singular_vectors" in resultdata @@ -120,7 +121,7 @@ def test_optimization(suite, index_dependent, grouped, weight, method): assert np.array_equal(dataset.coords["model"], resultdata.coords["model"]) assert np.array_equal(dataset.coords["global"], resultdata.coords["global"]) assert dataset.data.shape == resultdata.data.shape - print(dataset.data[0, 0], resultdata.data[0, 0]) + print(dataset.data[0, 0], resultdata.data[0, 0]) # noqa T001 assert np.allclose(dataset.data, resultdata.data) if weight: assert "weight" in resultdata @@ -128,3 +129,41 @@ def test_optimization(suite, index_dependent, grouped, weight, method): assert "weighted_residual_left_singular_vectors" in resultdata assert "weighted_residual_right_singular_vectors" in resultdata assert 
"weighted_residual_singular_values" in resultdata + + +@pytest.mark.parametrize("index_dependent", [True, False]) +def test_optimization_full_model(index_dependent): + model = FullModel.model + model.megacomplex["m1"].is_index_dependent = index_dependent + + print(model.validate()) # noqa T001 + assert model.valid() + + parameters = FullModel.parameters + assert model.valid(parameters) + + dataset = simulate(model, "dataset1", parameters, FullModel.coordinates) + + scheme = Scheme( + model=model, + parameters=parameters, + data={"dataset1": dataset}, + maximum_number_function_evaluations=10, + group=False, + ) + + result = optimize(scheme, raise_exception=True) + assert result.success + optimized_scheme = result.get_scheme() + assert result.optimized_parameters == optimized_scheme.parameters + + result_data = result.data["dataset1"] + assert "fitted_data" in result_data + for label, param in result.optimized_parameters.all(): + if param.vary: + assert np.allclose(param.value, parameters.get(label).value, rtol=1e-1) + + clp = result_data.clp + print(clp) # noqa T001 + assert clp.shape == (4, 4) + assert all(np.isclose(1.0, c) for c in np.diagonal(clp)) diff --git a/glotaran/analysis/test/test_problem.py b/glotaran/analysis/test/test_problem.py index 08344beeb..241618b2a 100644 --- a/glotaran/analysis/test/test_problem.py +++ b/glotaran/analysis/test/test_problem.py @@ -8,6 +8,7 @@ from glotaran.analysis.problem_grouped import GroupedProblem from glotaran.analysis.problem_ungrouped import UngroupedProblem from glotaran.analysis.simulation import simulate +from glotaran.analysis.test.models import FullModel from glotaran.analysis.test.models import MultichannelMulticomponentDecay as suite from glotaran.analysis.test.models import SimpleTestModel from glotaran.analysis.util import CalculatedMatrix @@ -146,7 +147,7 @@ def test_prepare_data(): ], } model = SimpleTestModel.from_dict(model_dict) - print(model.validate()) # T001 + print(model.validate()) # T001 # noqa T001 assert model.valid() parameters = ParameterGroup.from_list([]) @@ -164,7 +165,7 @@ def test_prepare_data(): problem = Problem(scheme) data = problem.data["dataset1"] - print(data) # T001 + print(data) # noqa T001 assert "data" in data assert "weight" in data @@ -180,7 +181,7 @@ def test_prepare_data(): } ) model = SimpleTestModel.from_dict(model_dict) - print(model.validate()) # T001 + print(model.validate()) # T001 # noqa T001 assert model.valid() scheme = Scheme(model, parameters, {"dataset1": dataset}) @@ -197,3 +198,21 @@ def test_prepare_data(): " because weight is already supplied by dataset.", ): Problem(Scheme(model, parameters, {"dataset1": data})) + + +def test_full_model_problem(): + dataset = simulate(FullModel.model, "dataset1", FullModel.parameters, FullModel.coordinates) + scheme = Scheme( + model=FullModel.model, parameters=FullModel.parameters, data={"dataset1": dataset} + ) + problem = UngroupedProblem(scheme) + + result = problem.create_result_data()["dataset1"] + assert "global_matrix" in result + assert "global_clp_label" in result + + clp = result.clp + + assert clp.shape == (4, 4) + print(np.diagonal(clp)) # noqa T001 + assert all(np.isclose(1.0, c) for c in np.diagonal(clp)) diff --git a/glotaran/analysis/util.py b/glotaran/analysis/util.py index 342d93c60..25b44c3b2 100644 --- a/glotaran/analysis/util.py +++ b/glotaran/analysis/util.py @@ -48,8 +48,8 @@ def get_min_max_from_interval(interval, axis): def calculate_matrix( dataset_model: DatasetModel, - indices: dict[str, int] | None, - global_model: bool = 
False, + indices: dict[str, int], + as_global_model: bool = False, ) -> CalculatedMatrix: clp_labels = None @@ -57,7 +57,7 @@ def calculate_matrix( megacomplex_iterator = dataset_model.iterate_megacomplexes - if global_model: + if as_global_model: megacomplex_iterator = dataset_model.iterate_global_megacomplexes dataset_model.swap_dimensions() @@ -81,7 +81,7 @@ def calculate_matrix( clp_labels = tmp_clp_labels matrix = tmp_matrix - if global_model: + if as_global_model: dataset_model.swap_dimensions() return CalculatedMatrix(clp_labels, matrix) diff --git a/glotaran/analysis/variable_projection.py b/glotaran/analysis/variable_projection.py index 9bb420a64..2f9068795 100644 --- a/glotaran/analysis/variable_projection.py +++ b/glotaran/analysis/variable_projection.py @@ -9,7 +9,7 @@ def residual_variable_projection( matrix: np.ndarray, data: np.ndarray -) -> typing.Tuple[typing.List[str], np.ndarray]: +) -> typing.Tuple[np.ndarray, np.ndarray]: """Calculates the conditionally linear parameters and residual with the variable projection method. diff --git a/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py b/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py index b701b9c0c..9dbadbe96 100644 --- a/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py +++ b/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py @@ -24,5 +24,12 @@ def calculate_matrix( def index_dependent(self, dataset: DatasetModel) -> bool: return False - def finalize_data(self, dataset_model: DatasetModel, data: xr.Dataset): - data[f"{dataset_model.label}_baseline"] = data.clp.sel(clp_label="baseline") + def finalize_data( + self, + dataset_model: DatasetModel, + dataset: xr.Dataset, + is_full_model: bool = False, + as_global: bool = False, + ): + if not is_full_model: + dataset[f"{dataset_model.label}_baseline"] = dataset.clp.sel(clp_label="baseline") diff --git a/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py b/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py index 8e9cb4715..4f113a929 100644 --- a/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py +++ b/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py @@ -62,18 +62,25 @@ def compartments(self): def index_dependent(self, dataset: DatasetModel) -> bool: return False - def finalize_data(self, dataset_model: DatasetModel, data: xr.Dataset): - global_dimension = dataset_model.get_global_dimension() - model_dimension = dataset_model.get_model_dimension() - data.coords["coherent_artifact_order"] = list(range(1, self.order + 1)) - data["coherent_artifact_concentration"] = ( - (model_dimension, "coherent_artifact_order"), - data.matrix.sel(clp_label=self.compartments()).values, - ) - data["coherent_artifact_associated_spectra"] = ( - (global_dimension, "coherent_artifact_order"), - data.clp.sel(clp_label=self.compartments()).values, - ) + def finalize_data( + self, + dataset_model: DatasetModel, + dataset: xr.Dataset, + is_full_model: bool = False, + as_global: bool = False, + ): + if not is_full_model: + global_dimension = dataset_model.get_global_dimension() + model_dimension = dataset_model.get_model_dimension() + dataset.coords["coherent_artifact_order"] = np.arange(1, self.order + 1) + dataset["coherent_artifact_concentration"] = ( + (model_dimension, "coherent_artifact_order"), + dataset.matrix.sel(clp_label=self.compartments()).values, + ) + dataset["coherent_artifact_associated_spectra"] = ( + 
(global_dimension, "coherent_artifact_order"), + dataset.clp.sel(clp_label=self.compartments()).values, + ) @nb.jit(nopython=True, parallel=True) diff --git a/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py b/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py index 25bddf55d..e57f31f5e 100644 --- a/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py @@ -112,19 +112,40 @@ def calculate_matrix( # done return species, matrix - def finalize_data(self, dataset_model: DatasetModel, data: xr.Dataset): + def finalize_data( + self, + dataset_model: DatasetModel, + dataset: xr.Dataset, + is_full_model: bool = False, + as_global: bool = False, + ): global_dimension = dataset_model.get_global_dimension() name = "images" if global_dimension == "pixel" else "spectra" - if "species" not in data.coords: + species_dimension = "decay_species" if as_global else "species" + if species_dimension not in dataset.coords: # We are the first Decay complex called and add SAD for all decay megacomplexes - retrieve_species_associated_data(dataset_model, data, global_dimension, name) - if isinstance(dataset_model.irf, IrfMultiGaussian) and "irf" not in data: - retrieve_irf(dataset_model, data, global_dimension) + retrieve_species_associated_data( + dataset_model, + dataset, + species_dimension, + global_dimension, + name, + is_full_model, + as_global, + ) + if isinstance(dataset_model.irf, IrfMultiGaussian) and "irf" not in dataset: + retrieve_irf(dataset_model, dataset, global_dimension) - multiple_complexes = ( - len([m for m in dataset_model.megacomplex if isinstance(m, DecayMegacomplex)]) > 1 - ) - retrieve_decay_associated_data( - self, dataset_model, data, global_dimension, name, multiple_complexes - ) + if not is_full_model: + multiple_complexes = ( + len([m for m in dataset_model.megacomplex if isinstance(m, DecayMegacomplex)]) > 1 + ) + retrieve_decay_associated_data( + self, + dataset_model, + dataset, + global_dimension, + name, + multiple_complexes, + ) diff --git a/glotaran/builtin/megacomplexes/decay/util.py b/glotaran/builtin/megacomplexes/decay/util.py index 7c22c768a..b984fe4cd 100644 --- a/glotaran/builtin/megacomplexes/decay/util.py +++ b/glotaran/builtin/megacomplexes/decay/util.py @@ -106,45 +106,56 @@ def calculate_decay_matrix_gaussian_irf( def retrieve_species_associated_data( - dataset_model: DatasetModel, data: xr.Dataset, global_dimension: str, name: str + dataset_model: DatasetModel, + dataset: xr.Dataset, + species_dimension: str, + global_dimension: str, + name: str, + is_full_model: bool, + as_global: bool, ): species = dataset_model.initial_concentration.compartments model_dimension = dataset_model.get_model_dimension() + if as_global: + model_dimension, global_dimension = global_dimension, model_dimension + dataset.coords[species_dimension] = species + matrix = dataset.global_matrix if as_global else dataset.matrix + clp_dim = "global_clp_label" if as_global else "clp_label" - data.coords["species"] = species - data[f"species_associated_{name}"] = ( - ( - global_dimension, - "species", - ), - data.clp.sel(clp_label=species).data, - ) - - if len(data.matrix.shape) == 3: + if len(dataset.matrix.shape) == 3: # index dependent - data["species_concentration"] = ( + dataset["species_concentration"] = ( ( global_dimension, model_dimension, - "species", + species_dimension, ), - data.matrix.sel(clp_label=species).values, + matrix.sel({clp_dim: species}).values, ) else: # index independent - 
data["species_concentration"] = ( + dataset["species_concentration"] = ( ( model_dimension, - "species", + species_dimension, + ), + matrix.sel({clp_dim: species}).values, + ) + + if not is_full_model: + dataset[f"species_associated_{name}"] = ( + ( + global_dimension, + species_dimension, ), - data.matrix.sel(clp_label=species).values, + dataset.clp.sel(clp_label=species).data, ) def retrieve_decay_associated_data( megacomplex: DecayMegacomplex, dataset_model: DatasetModel, - data: xr.Dataset, + dataset: xr.Dataset, global_dimension: str, name: str, multiple_complexes: bool, @@ -160,11 +171,11 @@ def retrieve_decay_associated_data( rates = k_matrix.rates(dataset_model.initial_concentration) lifetimes = 1 / rates - das = data[f"species_associated_{name}"].sel(species=species).values @ a_matrix.T + das = dataset[f"species_associated_{name}"].sel(species=species).values @ a_matrix.T component_coords = {"rate": ("component", rates), "lifetime": ("component", lifetimes)} das_coords = component_coords.copy() - das_coords[global_dimension] = data.coords[global_dimension] + das_coords[global_dimension] = dataset.coords[global_dimension] das_name = f"decay_associated_{name}" das = xr.DataArray(das, dims=(global_dimension, "component"), coords=das_coords) @@ -189,32 +200,34 @@ def retrieve_decay_associated_data( k_matrix_name = f"k_matrix_{megacomplex.label}" k_matrix_reduced_name = f"k_matrix_reduced_{megacomplex.label}" - data[das_name] = das - data[a_matrix_name] = a_matrix - data[k_matrix_name] = k_matrix - data[k_matrix_reduced_name] = k_matrix_reduced + dataset[das_name] = das + dataset[a_matrix_name] = a_matrix + dataset[k_matrix_name] = k_matrix + dataset[k_matrix_reduced_name] = k_matrix_reduced -def retrieve_irf(dataset_model: DatasetModel, data: xr.Dataset, global_dimension: str): +def retrieve_irf(dataset_model: DatasetModel, dataset: xr.Dataset, global_dimension: str): irf = dataset_model.irf model_dimension = dataset_model.get_model_dimension() - data["irf"] = ( + dataset["irf"] = ( (model_dimension), irf.calculate( index=0, - global_axis=data.coords[global_dimension].values, - model_axis=data.coords[model_dimension].values, + global_axis=dataset.coords[global_dimension].values, + model_axis=dataset.coords[model_dimension].values, ).data, ) center = irf.center if isinstance(irf.center, list) else [irf.center] width = irf.width if isinstance(irf.width, list) else [irf.width] - data["irf_center"] = ("irf_nr", center) if len(center) > 1 else center[0] - data["irf_width"] = ("irf_nr", width) if len(width) > 1 else width[0] + dataset["irf_center"] = ("irf_nr", center) if len(center) > 1 else center[0] + dataset["irf_width"] = ("irf_nr", width) if len(width) > 1 else width[0] if isinstance(irf, IrfSpectralMultiGaussian) and irf.dispersion_center: - for i, dispersion in enumerate(irf.calculate_dispersion(data.coords["spectral"].values)): - data[f"center_dispersion_{i+1}"] = ( + for i, dispersion in enumerate( + irf.calculate_dispersion(dataset.coords["spectral"].values) + ): + dataset[f"center_dispersion_{i+1}"] = ( global_dimension, dispersion, ) diff --git a/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py b/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py index 5b1ff20ad..b25ec7151 100644 --- a/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py +++ b/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py @@ -50,29 +50,42 @@ def calculate_matrix( def index_dependent(self, dataset: DatasetModel) -> bool: return False - def 
finalize_data(self, dataset_model: DatasetModel, data: xr.Dataset): - if "species" in data.coords: + def finalize_data( + self, + dataset_model: DatasetModel, + dataset: xr.Dataset, + is_full_model: bool = False, + as_global: bool = False, + ): + species_dimension = "spectral_species" if as_global else "species" + if species_dimension in dataset.coords: return species = [] - for megacomplex in dataset_model.megacomplex: # noqa F402 - if isinstance(megacomplex, SpectralMegacomplex): - species += [ - compartment for compartment in megacomplex.shape if compartment not in species - ] - - data.coords["species"] = species - data["species_spectra"] = ( - ( - dataset_model.get_model_dimension(), - "species", - ), - data.matrix.sel(clp_label=species).values, + megacomplexes = ( + dataset_model.global_megacomplex if as_global else dataset_model.megacomplex ) - data["species_associated_concentrations"] = ( + for m in megacomplexes: + if isinstance(m, SpectralMegacomplex): + species += [compartment for compartment in m.shape if compartment not in species] + + dataset.coords[species_dimension] = species + matrix = dataset.global_matrix if as_global else dataset.matrix + clp_dim = "global_clp_label" if as_global else "clp_label" + dataset["species_spectra"] = ( ( - dataset_model.get_global_dimension(), - "species", + dataset_model.get_model_dimension() + if not as_global + else dataset_model.get_global_dimension(), + species_dimension, ), - data.clp.sel(clp_label=species).data, + matrix.sel({clp_dim: species}).values, ) + if not is_full_model: + dataset["species_associated_concentrations"] = ( + ( + dataset_model.get_global_dimension(), + species_dimension, + ), + dataset.clp.sel(clp_label=species).data, + ) diff --git a/glotaran/model/dataset_model.py b/glotaran/model/dataset_model.py index 749db9a0e..d86a99ad3 100644 --- a/glotaran/model/dataset_model.py +++ b/glotaran/model/dataset_model.py @@ -63,9 +63,15 @@ def get_model_dimension(self) -> str: ) return self._model_dimension - def finalize_data(self, data: xr.Dataset): + def finalize_data(self, dataset: xr.Dataset): + is_full_model = self.has_global_model() for megacomplex in self.megacomplex: - megacomplex.finalize_data(self, data) + megacomplex.finalize_data(self, dataset, is_full_model=is_full_model) + if is_full_model: + for megacomplex in self.global_megacomplex: + megacomplex.finalize_data( + self, dataset, is_full_model=is_full_model, as_global=True + ) def overwrite_model_dimension(self, model_dimension: str): """Overwrites the dataset model's model dimension.""" @@ -79,7 +85,7 @@ def overwrite_model_dimension(self, model_dimension: str): def get_global_dimension(self) -> str: """Returns the dataset model's global dimension.""" if not hasattr(self, "_global_dimension"): - if self.global_model(): + if self.has_global_model(): if isinstance(self.global_megacomplex[0], str): raise ValueError(f"Dataset descriptor '{self.label}' was not filled") self._global_dimension = self.global_megacomplex[0].dimension @@ -109,11 +115,13 @@ def swap_dimensions(self): self.overwrite_global_dimension(global_dimension) self.overwrite_model_dimension(model_dimension) - def set_data(self, data: xr.Dataset) -> DatasetModel: + def set_data(self, dataset: xr.Dataset) -> DatasetModel: """Sets the dataset model's data.""" - self._coords = {name: dim.values for name, dim in data.coords.items()} - self._data = data.data.values - self._weight = data.weight.values if "weight" in data else None + self._coords: dict[str, np.ndarray] = { + name: dim.values for name, dim in 
dataset.coords.items() + } + self._data: np.ndarray = dataset.data.values + self._weight: np.ndarray | None = dataset.weight.values if "weight" in dataset else None if self._weight is not None: self._data = self._data * self._weight return self @@ -122,7 +130,7 @@ def get_data(self) -> np.ndarray: """Gets the dataset model's data.""" return self._data - def get_weight(self) -> np.ndarray: + def get_weight(self) -> np.ndarray | None: """Gets the dataset model's weight.""" return self._weight @@ -136,7 +144,7 @@ def overwrite_index_dependent(self, index_dependent: bool): """Overrides the index dependency of the dataset""" self._index_dependent = index_dependent - def global_model(self) -> bool: + def has_global_model(self) -> bool: """Indicates if the dataset model can model the global dimension.""" return len(self.global_megacomplex) != 0 diff --git a/glotaran/model/interval_property.py b/glotaran/model/interval_property.py index 78b7d5e40..0b3d43bd2 100644 --- a/glotaran/model/interval_property.py +++ b/glotaran/model/interval_property.py @@ -1,7 +1,6 @@ """Helper functions.""" from __future__ import annotations -from typing import Any from typing import List from typing import Tuple @@ -10,7 +9,7 @@ @model_item( properties={ - "interval": {"type": List[Tuple[Any, Any]], "default": None, "allow_none": True}, + "interval": {"type": List[Tuple[float, float]], "default": None, "allow_none": True}, }, has_label=False, ) @@ -20,13 +19,13 @@ class IntervalProperty: :math:`source = parameter * target`. """ - def applies(self, index: Any) -> bool: + def applies(self, value: float) -> bool: """ Returns true if the index is in one of the intervals. Parameters ---------- - index : + value : float Returns ------- @@ -37,8 +36,8 @@ def applies(self, index: Any) -> bool: return True def applies(interval): - return interval[0] <= index <= interval[1] + return interval[0] <= value <= interval[1] if isinstance(self.interval, tuple): return applies(self.interval) - return not any([applies(i) for i in self.interval]) + return any([applies(i) for i in self.interval]) diff --git a/glotaran/model/megacomplex.py b/glotaran/model/megacomplex.py index 3e568c6ad..471dcf4ac 100644 --- a/glotaran/model/megacomplex.py +++ b/glotaran/model/megacomplex.py @@ -115,7 +115,13 @@ def calculate_matrix( def index_dependent(self, dataset_model: DatasetModel) -> bool: raise NotImplementedError - def finalize_data(self, dataset_model: DatasetModel, data: xr.Dataset): + def finalize_data( + self, + dataset_model: DatasetModel, + dataset: xr.Dataset, + is_full_model: bool = False, + as_global: bool = False, + ): raise NotImplementedError @classmethod diff --git a/glotaran/model/model.py b/glotaran/model/model.py index b84427142..60ec800b0 100644 --- a/glotaran/model/model.py +++ b/glotaran/model/model.py @@ -5,6 +5,8 @@ from typing import List from warnings import warn +import xarray as xr + from glotaran.model.clp_penalties import EqualAreaPenalty from glotaran.model.constraint import Constraint from glotaran.model.dataset_model import create_dataset_model_type @@ -214,10 +216,23 @@ def global_megacomplex(self) -> dict[str, Megacomplex]: """Alias for `glotaran.model.megacomplex`. Needed internally.""" return self.megacomplex - def need_index_dependent(self): + def need_index_dependent(self) -> bool: """Returns true if e.g. 
relations with intervals are present.""" return any(i.interval is not None for i in self.constraints + self.relations) + def is_groupable(self, parameters: ParameterGroup, data: dict[str, xr.DataArray]) -> bool: + if any(d.has_global_model() for d in self.dataset.values()): + return False + global_dimensions = { + d.fill(self, parameters).set_data(data[k]).get_global_dimension() + for k, d in self.dataset.items() + } + model_dimensions = { + d.fill(self, parameters).set_data(data[k]).get_model_dimension() + for k, d in self.dataset.items() + } + return len(global_dimensions) == 1 and len(model_dimensions) == 1 + def problem_list(self, parameters: ParameterGroup = None) -> list[str]: """ Returns a list with all problems in the model and missing parameters if specified. diff --git a/glotaran/model/test/test_model.py b/glotaran/model/test/test_model.py index 4676b1abe..c28c34c50 100644 --- a/glotaran/model/test/test_model.py +++ b/glotaran/model/test/test_model.py @@ -1,3 +1,5 @@ +from math import inf +from math import nan from typing import Dict from typing import List from typing import Tuple @@ -12,6 +14,9 @@ from glotaran.model import model_item from glotaran.model.clp_penalties import EqualAreaPenalty from glotaran.model.constraint import Constraint +from glotaran.model.constraint import OnlyConstraint +from glotaran.model.constraint import ZeroConstraint +from glotaran.model.interval_property import IntervalProperty from glotaran.model.model import Model from glotaran.model.relation import Relation from glotaran.model.weight import Weight @@ -257,13 +262,13 @@ def test_model_misc(test_model: Model): def test_model_validity(test_model: Model, model_error: Model, parameter: ParameterGroup): - print(test_model.test_item1["t1"]) - print(test_model.problem_list()) - print(test_model.problem_list(parameter)) + print(test_model.test_item1["t1"]) # noqa T001 + print(test_model.problem_list()) # noqa T001 + print(test_model.problem_list(parameter)) # noqa T001 assert test_model.valid() assert test_model.valid(parameter) - print(model_error.problem_list()) - print(model_error.problem_list(parameter)) + print(model_error.problem_list()) # noqa T001 + print(model_error.problem_list(parameter)) # noqa T001 assert not model_error.valid() assert len(model_error.problem_list()) == 5 assert not model_error.valid(parameter) @@ -324,7 +329,7 @@ def test_fill(test_model: Model, parameter: ParameterGroup): assert dataset.get_model_dimension() == "model" assert dataset.get_global_dimension() == "global" - assert not dataset.global_model() + assert not dataset.has_global_model() dataset = test_model.dataset.get("dataset2").fill(test_model, parameter) assert [cmplx.label for cmplx in dataset.megacomplex] == ["m2"] @@ -332,7 +337,7 @@ def test_fill(test_model: Model, parameter: ParameterGroup): assert dataset.get_model_dimension() == "model2" assert dataset.get_global_dimension() == "model" - assert dataset.global_model() + assert dataset.has_global_model() assert [cmplx.label for cmplx in dataset.global_megacomplex] == ["m1"] t = test_model.test_item1.get("t1").fill(test_model, parameter) @@ -368,3 +373,32 @@ def test_model_ipython_rendering(test_model: Model): assert "text/markdown" in rendered_markdown_return assert rendered_markdown_return["text/markdown"].startswith("# Model") + + +def test_interval_property(): + ip1 = IntervalProperty.from_dict({"interval": [[1, 1000]]}) + assert all(ip1.applies(x) for x in (1, 500, 100)) + assert all(not ip1.applies(x) for x in (9999, inf, nan)) + + +def 
test_zero_constraint(): + zc1 = ZeroConstraint.from_dict({"interval": [[1, 400], [600, 1000]], "target": "s1"}) + assert all(zc1.applies(x) for x in (1, 2, 400, 600, 1000)) + assert all(not zc1.applies(x) for x in (400.01, 500, 599.99, 9999, inf, nan)) + assert zc1.target == "s1" + zc2 = ZeroConstraint.from_dict({"interval": [[600, 700]], "target": "s2"}) + assert all(zc2.applies(x) for x in range(600, 700, 50)) + assert all(not zc2.applies(x) for x in (599.9999, 700.0001)) + assert zc2.target == "s2" + + +def test_only_constraint(): + oc1 = OnlyConstraint.from_dict({"interval": [[1, 400], (600, 1000)], "target": "spectra1"}) + assert all(oc1.applies(x) for x in (400.01, 500, 599.99, 9999, inf)) + assert all(not oc1.applies(x) for x in (1, 400, 600, 1000)) + assert oc1.target == "spectra1" + oc2 = OnlyConstraint.from_dict({"interval": [(600, 700)], "target": "spectra2"}) + assert oc2.applies(599) + assert not oc2.applies(650) + assert oc2.applies(701) + assert oc2.target == "spectra2" diff --git a/glotaran/project/scheme.py b/glotaran/project/scheme.py index bd5df63a0..b1c73c5f5 100644 --- a/glotaran/project/scheme.py +++ b/glotaran/project/scheme.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from dataclasses import dataclass from typing import TYPE_CHECKING @@ -33,7 +34,7 @@ class Scheme: model: Model | str parameters: ParameterGroup | str data: dict[str, xr.DataArray | xr.Dataset | str] - group: bool = True + group: bool | None = None group_tolerance: float = 0.0 non_negative_least_squares: bool = False maximum_number_function_evaluations: int = None @@ -75,6 +76,15 @@ def markdown(self): return MarkdownStr(markdown_str) + def is_grouped(self) -> bool: + """Returns whether the scheme should be grouped.""" + if self.group is not None and not self.group: + return False + is_groupable = self.model.is_groupable(self.parameters, self.data) + if not is_groupable and self.group is not None: + warnings.warn("Cannot group scheme. 
Continuing ungrouped.") + return is_groupable + def _repr_markdown_(self) -> str: """Special method used by ``ipython`` to render markdown.""" return str(self.markdown()) diff --git a/glotaran/test/test_spectral_decay.py b/glotaran/test/test_spectral_decay.py index d18b5534b..d0f86cecf 100644 --- a/glotaran/test/test_spectral_decay.py +++ b/glotaran/test/test_spectral_decay.py @@ -7,6 +7,50 @@ from glotaran.io import load_parameters from glotaran.project import Scheme +MODEL_3C_NO_IRF = """\ +dataset: + dataset1: + megacomplex: [mc1] + global_megacomplex: [mc2] + initial_concentration: j1 +megacomplex: + mc1: + type: decay + k_matrix: [k1] + mc2: + type: spectral + shape: + s1: sh1 + s2: sh2 + s3: sh3 +shape: + sh1: + type: skewed-gaussian + amplitude: shapes.amps.1 + location: shapes.locs.1 + width: shapes.width.1 + sh2: + type: skewed-gaussian + amplitude: shapes.amps.2 + location: shapes.locs.2 + width: shapes.width.2 + sh3: + type: skewed-gaussian + amplitude: shapes.amps.3 + location: shapes.locs.3 + width: shapes.width.3 +initial_concentration: + j1: + compartments: [s1, s2, s3] + parameters: [j.1, j.1, j.1] +k_matrix: + k1: + matrix: + (s1, s1): "kinetic.1" + (s2, s2): "kinetic.2" + (s3, s3): "kinetic.3" +""" + MODEL_3C_BASE = """\ dataset: dataset1: &dataset1 @@ -75,6 +119,32 @@ (s3, s3): "kinetic.3" """ +PARAMETERS_3C_NO_IRF = """\ +j: + - ["1", 1, {"vary": False, "non-negative": False}] + - ["0", 0, {"vary": False, "non-negative": False}] +shapes: + amps: [7, 3, 30, {"vary": False}] + locs: [620, 670, 720, {"vary": False}] + width: [10, 30, 50, {"vary": False}] +""" + +PARAMETERS_3C_NO_IRF_WANTED = f"""\ +kinetic: + - ["1", 301e-3] + - ["2", 502e-4] + - ["3", 705e-5] +{PARAMETERS_3C_NO_IRF} +""" + +PARAMETERS_3C_NO_IRF_INITIAL = f"""\ +kinetic: + - ["1", 300e-3] + - ["2", 500e-4] + - ["3", 700e-5] +{PARAMETERS_3C_NO_IRF} +""" + PARAMETERS_3C_BASE = """\ irf: - ["center", 1.3] @@ -134,6 +204,15 @@ """ +class ThreeComponentNoIrf: + model = load_model(MODEL_3C_NO_IRF, format_name="yml_str") + initial_parameters = load_parameters(PARAMETERS_3C_NO_IRF_INITIAL, format_name="yml_str") + wanted_parameters = load_parameters(PARAMETERS_3C_NO_IRF_WANTED, format_name="yml_str") + time = np.arange(0, 100, 1.5) + spectral = np.arange(600, 750, 10) + axis = {"time": time, "spectral": spectral} + + class ThreeComponentParallel: model = load_model(MODEL_3C_PARALLEL, format_name="yml_str") initial_parameters = load_parameters(PARAMETERS_3C_INITIAL_PARALLEL, format_name="yml_str") @@ -155,6 +234,7 @@ class ThreeComponentSequential: @pytest.mark.parametrize( "suite", [ + ThreeComponentNoIrf, ThreeComponentParallel, ThreeComponentSequential, ], @@ -163,19 +243,19 @@ class ThreeComponentSequential: def test_kinetic_model(suite, nnls): model = suite.model - print(model.validate()) + print(model.validate()) # noqa T001 assert model.valid() wanted_parameters = suite.wanted_parameters - print(model.validate(wanted_parameters)) - print(wanted_parameters) + print(model.validate(wanted_parameters)) # noqa T001 + print(wanted_parameters) # noqa T001 assert model.valid(wanted_parameters) initial_parameters = suite.initial_parameters - print(model.validate(initial_parameters)) + print(model.validate(initial_parameters)) # noqa T001 assert model.valid(initial_parameters) - print(model.markdown(wanted_parameters)) + print(model.markdown(wanted_parameters)) # noqa T001 dataset = simulate(model, "dataset1", wanted_parameters, suite.axis) @@ -189,22 +269,32 @@ def test_kinetic_model(suite, nnls): data=data, 
maximum_number_function_evaluations=20, non_negative_least_squares=nnls, + group=False, ) result = optimize(scheme) - print(result.optimized_parameters) + print(result.optimized_parameters) # noqa T001 for label, param in result.optimized_parameters.all(): - assert np.allclose(param.value, wanted_parameters.get(label).value, rtol=1e-1) + assert np.allclose(param.value, wanted_parameters.get(label).value) resultdata = result.data["dataset1"] - print(resultdata) + print(resultdata) # noqa T001 assert np.array_equal(dataset["time"], resultdata["time"]) assert np.array_equal(dataset["spectral"], resultdata["spectral"]) assert dataset.data.shape == resultdata.data.shape assert dataset.data.shape == resultdata.fitted_data.shape - assert np.allclose(dataset.data, resultdata.fitted_data, rtol=1e-2) + assert np.allclose(dataset.data, resultdata.fitted_data, rtol=1e-1) + + assert "species_spectra" in resultdata + spectra = resultdata.species_spectra + assert "spectral_species" in spectra.coords + assert "spectral" in spectra.coords + assert spectra.shape == (suite.axis["spectral"].size, 3) - assert "species_associated_spectra" in resultdata - assert "decay_associated_spectra" in resultdata + assert "species_concentration" in resultdata + concentration = resultdata.species_concentration + assert "species" in concentration.coords + assert "time" in concentration.coords + assert concentration.shape == (suite.axis["time"].size, 3) diff --git a/glotaran/test/test_spectral_decay_full_model.py b/glotaran/test/test_spectral_decay_full_model.py new file mode 100644 index 000000000..8f2e7e532 --- /dev/null +++ b/glotaran/test/test_spectral_decay_full_model.py @@ -0,0 +1,224 @@ +import numpy as np +import pytest + +from glotaran.analysis.optimize import optimize +from glotaran.analysis.simulation import simulate +from glotaran.io import load_model +from glotaran.io import load_parameters +from glotaran.project import Scheme + +MODEL_3C_BASE = """\ +dataset: + dataset1: + megacomplex: [mc1] + global_megacomplex: [mc2] + initial_concentration: j1 + irf: irf1 + dataset2: + megacomplex: [mc2] + global_megacomplex: [mc1] + initial_concentration: j1 + irf: irf1 + dataset3: + megacomplex: [mc1] + initial_concentration: j1 + irf: irf1 + dataset4: + megacomplex: [mc2] +megacomplex: + mc1: + type: decay + k_matrix: [k1] + mc2: + type: spectral + shape: + s1: sh1 + s2: sh2 + s3: sh3 +irf: + irf1: + type: spectral-multi-gaussian + center: [irf.center] + width: [irf.width] +shape: + sh1: + type: skewed-gaussian + amplitude: shapes.amps.1 + location: shapes.locs.1 + width: shapes.width.1 + sh2: + type: skewed-gaussian + amplitude: shapes.amps.2 + location: shapes.locs.2 + width: shapes.width.2 + sh3: + type: skewed-gaussian + amplitude: shapes.amps.3 + location: shapes.locs.3 + width: shapes.width.3 +""" + +MODEL_3C_PARALLEL = f"""\ +{MODEL_3C_BASE} +initial_concentration: + j1: + compartments: [s1, s2, s3] + parameters: [j.1, j.1, j.1] +k_matrix: + k1: + matrix: + (s1, s1): "kinetic.1" + (s2, s2): "kinetic.2" + (s3, s3): "kinetic.3" +""" + +MODEL_3C_SEQUENTIAL = f"""\ +{MODEL_3C_BASE} +initial_concentration: + j1: + compartments: [s1, s2, s3] + parameters: [j.1, j.0, j.0] +k_matrix: + k1: + matrix: + (s2, s1): "kinetic.1" + (s3, s2): "kinetic.2" + (s3, s3): "kinetic.3" +""" + +PARAMETERS_3C_BASE = """\ +irf: + - ["center", 1.3] + - ["width", 7.8] +j: + - ["1", 1, {"vary": False, "non-negative": False}] + - ["0", 0, {"vary": False, "non-negative": False}] +""" + +PARAMETERS_3C_BASE_PARALLEL = f"""\ +{PARAMETERS_3C_BASE} 
+shapes: + amps: [7, 3, 30, {{"vary": False}}] + locs: [620, 670, 720, {{"vary": False}}] + width: [10, 30, 50, {{"vary": False}}] +""" + +PARAMETERS_3C_BASE_SEQUENTIAL = f"""\ +{PARAMETERS_3C_BASE} +shapes: + amps: + - 9 + - 7 + - 5 + - {{"vary": True, min: 0, max: 10}} + locs: [610, 670, 730, {{"vary": False}}] + width: [15, 25, 10, {{"vary": False}}] +""" + +PARAMETERS_3C_PARALLEL_WANTED = f"""\ +kinetic: + - ["1", 301e-3] + - ["2", 502e-4] + - ["3", 705e-5] +{PARAMETERS_3C_BASE_PARALLEL} +""" + +PARAMETERS_3C_INITIAL_PARALLEL = f"""\ +kinetic: + - ["1", 300e-3, {{non-negative: true}}] + - ["2", 500e-4, {{non-negative: true}}] + - ["3", 700e-5, {{non-negative: true}}] +{PARAMETERS_3C_BASE_PARALLEL} +""" + +PARAMETERS_3C_SIM_SEQUENTIAL = f"""\ +kinetic: + - ["1", 501e-3, {{non-negative: true}}] + - ["2", 202e-4, {{non-negative: true}}] + - ["3", 105e-5, {{non-negative: true}}] +{PARAMETERS_3C_BASE_SEQUENTIAL} +""" + +PARAMETERS_3C_INITIAL_SEQUENTIAL = f"""\ +kinetic: + - ["1", 500e-3] + - ["2", 200e-4] + - ["3", 100e-5] + - {{"non-negative": True}} +{PARAMETERS_3C_BASE_SEQUENTIAL} +""" + + +class ThreeComponentParallel: + model = load_model(MODEL_3C_PARALLEL, format_name="yml_str") + initial_parameters = load_parameters(PARAMETERS_3C_INITIAL_PARALLEL, format_name="yml_str") + wanted_parameters = load_parameters(PARAMETERS_3C_PARALLEL_WANTED, format_name="yml_str") + time = np.arange(-10, 100, 1.5) + spectral = np.arange(600, 750, 10) + axis = {"time": time, "spectral": spectral} + + +class ThreeComponentSequential: + model = load_model(MODEL_3C_SEQUENTIAL, format_name="yml_str") + initial_parameters = load_parameters(PARAMETERS_3C_INITIAL_SEQUENTIAL, format_name="yml_str") + wanted_parameters = load_parameters(PARAMETERS_3C_SIM_SEQUENTIAL, format_name="yml_str") + time = np.arange(-10, 50, 1.0) + spectral = np.arange(600, 750, 5.0) + axis = {"time": time, "spectral": spectral} + + +@pytest.mark.parametrize( + "suite", + [ + ThreeComponentParallel, + ThreeComponentSequential, + ], +) +@pytest.mark.parametrize("nnls", [True, False]) +def test_kinetic_model(suite, nnls): + + model = suite.model + print(model.validate()) # noqa T001 + assert model.valid() + + wanted_parameters = suite.wanted_parameters + print(model.validate(wanted_parameters)) # noqa T001 + print(wanted_parameters) # noqa T001 + assert model.valid(wanted_parameters) + + initial_parameters = suite.initial_parameters + print(model.validate(initial_parameters)) # noqa T001 + assert model.valid(initial_parameters) + + print(model.markdown(wanted_parameters)) # noqa T001 + + dataset = simulate(model, "dataset1", wanted_parameters, suite.axis) + + assert dataset.data.shape == (suite.axis["time"].size, suite.axis["spectral"].size) + + data = {f"dataset{i}": dataset for i in range(1, 5)} + + scheme = Scheme( + model=model, + parameters=initial_parameters, + data=data, + maximum_number_function_evaluations=20, + non_negative_least_squares=nnls, + group=False, + ) + result = optimize(scheme) + print(result.optimized_parameters) # noqa T001 + + for label, param in result.optimized_parameters.all(): + print(label, param.value, wanted_parameters.get(label).value) # noqa T001 + assert np.allclose(param.value, wanted_parameters.get(label).value, rtol=1e-1) + + resultdata = result.data["dataset1"] + + print(resultdata) # noqa T001 + + assert np.array_equal(dataset["time"], resultdata["time"]) + assert np.array_equal(dataset["spectral"], resultdata["spectral"]) + assert dataset.data.shape == resultdata.data.shape + assert dataset.data.shape 
== resultdata.fitted_data.shape + assert np.allclose(dataset.data, resultdata.fitted_data, rtol=1e-2) diff --git a/glotaran/test/test_spectral_penalties.py b/glotaran/test/test_spectral_penalties.py index 000c8f418..a00c9b853 100644 --- a/glotaran/test/test_spectral_penalties.py +++ b/glotaran/test/test_spectral_penalties.py @@ -194,7 +194,7 @@ def test_equal_area_penalties(debug=False): model_sim = SpectralDecayModel.from_dict(mspec_sim) model_wp = SpectralDecayModel.from_dict(mspec_fit_wp) model_np = SpectralDecayModel.from_dict(mspec_fit_np) - print(model_np) + print(model_np) # noqa T001 # %% Parameter specification (pspec) @@ -214,9 +214,9 @@ def test_equal_area_penalties(debug=False): param_np = ParameterGroup.from_dict(pspec_np) # %% Print models with parameters - print(model_sim.markdown(param_sim)) - print(model_wp.markdown(param_wp)) - print(model_np.markdown(param_np)) + print(model_sim.markdown(param_sim)) # noqa T001 + print(model_wp.markdown(param_wp)) # noqa T001 + print(model_np.markdown(param_np)) # noqa T001 # %% simulated_data = simulate( @@ -244,7 +244,7 @@ def test_equal_area_penalties(debug=False): maximum_number_function_evaluations=optim_spec.max_nfev, ) result_np = optimize(scheme_np) - print(result_np) + print(result_np) # noqa T001 # %% Optimizing model with penalty fixed inputs (wp_ifix) scheme_wp = Scheme( @@ -255,7 +255,7 @@ def test_equal_area_penalties(debug=False): maximum_number_function_evaluations=optim_spec.max_nfev, ) result_wp = optimize(scheme_wp) - print(result_wp) + print(result_wp) # noqa T001 if debug: # %% Plot results @@ -268,10 +268,10 @@ def test_equal_area_penalties(debug=False): plt.show() # %% Test calculation - print(result_wp.data["dataset1"]) + print(result_wp.data["dataset1"]) # noqa T001 area1_np = np.sum(result_np.data["dataset1"].species_associated_spectra.sel(species="s1")) area2_np = np.sum(result_np.data["dataset1"].species_associated_spectra.sel(species="s2")) - print("area_np", area1_np, area2_np) + print("area_np", area1_np, area2_np) # noqa T001 assert not np.isclose(area1_np, area2_np) area1_wp = np.sum(result_wp.data["dataset1"].species_associated_spectra.sel(species="s1")) @@ -281,7 +281,7 @@ def test_equal_area_penalties(debug=False): input_ratio = result_wp.optimized_parameters.get("i.1") / result_wp.optimized_parameters.get( "i.2" ) - print("input", input_ratio) + print("input", input_ratio) # noqa T001 assert np.isclose(input_ratio, 1.5038858115) From 677a992dfa88980b16fe3a79202f381730dcc4f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn=20Wei=C3=9Fenborn?= Date: Sat, 7 Aug 2021 16:44:36 +0200 Subject: [PATCH 07/29] Small fix for baseline megacomplex (#762) * Small fix for baseline megacomplex * Save baseline as name 'baseline' --- glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py b/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py index 9dbadbe96..7e1a35bb4 100644 --- a/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py +++ b/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py @@ -32,4 +32,4 @@ def finalize_data( as_global: bool = False, ): if not is_full_model: - dataset[f"{dataset_model.label}_baseline"] = dataset.clp.sel(clp_label="baseline") + dataset["baseline"] = dataset.clp.sel(clp_label=f"{dataset_model.label}_baseline") From 8adfe8f74a8de2ccbf6539dc951a8a96f7980428 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?J=C3=B6rn=20Wei=C3=9Fenborn?= Date: Thu, 12 Aug 2021 20:33:18 +0200 Subject: [PATCH 08/29] Refactor/spectral model (#763) * Added SpectralShapeGaussian. * Replaced property energy_spectrum of SpectralMegacomplex with invert and axis_scale. * Made invert and axis_scale dataset properties. * Added guard and meaningful exception message for spectral skewness * Fixed scaling * Made SpectralShapeSkewedGaussian decendent of SpectralShapeGaussian and added fallback for skewness == 0. * Move sanatize.py to utils module and correct typo sanatize -> sanitize (incl rename to sanitize.py) * Move regex patterns to seperate module in utils * Add sanity_scientific_notation_conversion Convert scientific notation string (e.g. 1E7) to proper floats * Fixed spectral megacomplex test parameters and added test for inverted axis * Removed model_item._from_list * Added convenience in model_item.from_dict for automatically convert float or int typed properties which are parsed as strings. * Made amplitude of shape optional Co-authored-by: Joris Snellenburg --- .../builtin/io/yml/test/test_model_parser.py | 21 ++- .../builtin/io/yml/test/test_model_spec.yml | 4 +- glotaran/builtin/io/yml/yml.py | 4 +- .../builtin/megacomplexes/spectral/shape.py | 75 ++++------ .../spectral/spectral_megacomplex.py | 11 +- .../spectral/test/test_spectral_model.py | 114 +++++++++++--- glotaran/examples/sequential.py | 6 +- glotaran/examples/test/test_example.py | 7 + glotaran/model/item.py | 37 +---- glotaran/model/property.py | 8 +- glotaran/parameter/parameter.py | 16 +- glotaran/test/test_spectral_decay.py | 12 +- .../test/test_spectral_decay_full_model.py | 6 +- glotaran/test/test_spectral_penalties.py | 4 +- glotaran/utils/regex.py | 16 ++ .../io/yml/sanatize.py => utils/sanitize.py} | 141 +++++++++++++----- .../test/test_sanitize.py} | 4 +- 17 files changed, 305 insertions(+), 181 deletions(-) create mode 100644 glotaran/examples/test/test_example.py create mode 100644 glotaran/utils/regex.py rename glotaran/{builtin/io/yml/sanatize.py => utils/sanitize.py} (65%) rename glotaran/{builtin/io/yml/test/test_util.py => utils/test/test_sanitize.py} (93%) diff --git a/glotaran/builtin/io/yml/test/test_model_parser.py b/glotaran/builtin/io/yml/test/test_model_parser.py index b19a48dab..8fa942b55 100644 --- a/glotaran/builtin/io/yml/test/test_model_parser.py +++ b/glotaran/builtin/io/yml/test/test_model_parser.py @@ -8,7 +8,7 @@ from glotaran.builtin.megacomplexes.decay.decay_megacomplex import DecayMegacomplex from glotaran.builtin.megacomplexes.decay.initial_concentration import InitialConcentration from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian -from glotaran.builtin.megacomplexes.spectral.shape import SpectralShapeSkewedGaussian +from glotaran.builtin.megacomplexes.spectral.shape import SpectralShapeGaussian from glotaran.io import load_model from glotaran.model import DatasetModel from glotaran.model import Model @@ -25,7 +25,7 @@ def model(): spec_path = join(THIS_DIR, "test_model_spec.yml") m = load_model(spec_path) - print(m.markdown()) + print(m.markdown()) # noqa return m @@ -48,9 +48,20 @@ def test_dataset(model): assert dataset.irf == "irf1" assert dataset.scale == 1 + assert "dataset2" in model.dataset + dataset = model.dataset["dataset2"] + assert isinstance(dataset, DatasetModel) + assert dataset.label == "dataset2" + assert dataset.megacomplex == ["cmplx2"] + assert dataset.initial_concentration == "inputD2" + assert dataset.irf == "irf2" + assert dataset.scale == 2 + assert 
dataset.spectral_axis_scale == 1e7 + assert dataset.spectral_axis_inverted + def test_constraints(model): - print(model.constraints) + print(model.constraints) # noqa assert len(model.constraints) == 2 zero = model.constraints[0] @@ -77,7 +88,7 @@ def test_penalties(model): def test_relations(model): - print(model.relations) + print(model.relations) # noqa assert len(model.relations) == 1 rel = model.relations[0] @@ -154,7 +165,7 @@ def test_shapes(model): assert "shape1" in model.shape shape = model.shape["shape1"] - assert isinstance(shape, SpectralShapeSkewedGaussian) + assert isinstance(shape, SpectralShapeGaussian) assert shape.amplitude.full_label == "shape.1" assert shape.location.full_label == "shape.2" assert shape.width.full_label == "shape.3" diff --git a/glotaran/builtin/io/yml/test/test_model_spec.yml b/glotaran/builtin/io/yml/test/test_model_spec.yml index 499e61b8f..cf5dfce21 100644 --- a/glotaran/builtin/io/yml/test/test_model_spec.yml +++ b/glotaran/builtin/io/yml/test/test_model_spec.yml @@ -12,6 +12,8 @@ dataset: initial_concentration: inputD2 irf: irf2 scale: 2 + spectral_axis_scale: 1e7 + spectral_axis_inverted: true irf: irf1: @@ -54,7 +56,7 @@ k_matrix: shape: shape1: - type: "skewed-gaussian" + type: "gaussian" amplitude: shape.1 location: shape.2 width: shape.3 diff --git a/glotaran/builtin/io/yml/yml.py b/glotaran/builtin/io/yml/yml.py index 63fe29884..f0b0a42df 100644 --- a/glotaran/builtin/io/yml/yml.py +++ b/glotaran/builtin/io/yml/yml.py @@ -7,8 +7,6 @@ import yaml -from glotaran.builtin.io.yml.sanatize import check_deprecations -from glotaran.builtin.io.yml.sanatize import sanitize_yaml from glotaran.io import ProjectIoInterface from glotaran.io import load_dataset from glotaran.io import load_model @@ -21,6 +19,8 @@ from glotaran.parameter import ParameterGroup from glotaran.project import SavingOptions from glotaran.project import Scheme +from glotaran.utils.sanitize import check_deprecations +from glotaran.utils.sanitize import sanitize_yaml if TYPE_CHECKING: from glotaran.project import Result diff --git a/glotaran/builtin/megacomplexes/spectral/shape.py b/glotaran/builtin/megacomplexes/spectral/shape.py index 7823590b0..f727b9ea6 100644 --- a/glotaran/builtin/megacomplexes/spectral/shape.py +++ b/glotaran/builtin/megacomplexes/spectral/shape.py @@ -9,52 +9,16 @@ @model_item( properties={ - "amplitude": Parameter, + "amplitude": {"type": Parameter, "allow_none": True}, "location": Parameter, "width": Parameter, - "skewness": {"type": Parameter, "allow_none": True}, }, has_type=True, ) -class SpectralShapeSkewedGaussian: - """A (skewed) Gaussian spectral shape""" +class SpectralShapeGaussian: + """A Gaussian spectral shape""" def calculate(self, axis: np.ndarray) -> np.ndarray: - r"""Calculate a (skewed) Gaussian shape for a given ``axis``. - - If a non-zero ``skewness`` parameter was added - :func:`calculate_skewed_gaussian` will be used. - Otherwise it will use :func:`calculate_gaussian`. - - Parameters - ---------- - axis: np.ndarray - The axis to calculate the shape for. - - Returns - ------- - shape: numpy.ndarray - A Gaussian shape. - - See Also - -------- - calculate_gaussian - calculate_skewed_gaussian - - Note - ---- - Internally ``axis`` is converted from :math:`\mbox{nm}` to - :math:`1/\mbox{cm}`, thus ``location`` and ``width`` also need to - be provided in :math:`1/\mbox{cm}` (``1e7/value_in_nm``). 
- - """ - return ( - self.calculate_skewed_gaussian(axis) - if self.skewness is not None and not np.allclose(self.skewness, 0) - else self.calculate_gaussian(axis) - ) - - def calculate_gaussian(self, axis: np.ndarray) -> np.ndarray: r"""Calculate a normal Gaussian shape for a given ``axis``. The following equation is used for the calculation: @@ -91,11 +55,22 @@ def calculate_gaussian(self, axis: np.ndarray) -> np.ndarray: np.ndarray An array representing a Gaussian shape. """ - return self.amplitude * np.exp( - -np.log(2) * np.square(2 * (axis - self.location) / self.width) - ) + shape = np.exp(-np.log(2) * np.square(2 * (axis - self.location) / self.width)) + if self.amplitude is not None: + shape *= self.amplitude + return shape - def calculate_skewed_gaussian(self, axis: np.ndarray) -> np.ndarray: + +@model_item( + properties={ + "skewness": Parameter, + }, + has_type=True, +) +class SpectralShapeSkewedGaussian(SpectralShapeGaussian): + """A skewed Gaussian spectral shape""" + + def calculate(self, axis: np.ndarray) -> np.ndarray: r"""Calculate the skewed Gaussian shape for ``axis``. The following equation is used for the calculation: @@ -134,7 +109,7 @@ def calculate_skewed_gaussian(self, axis: np.ndarray) -> np.ndarray: Note that in the limit of skewness parameter :math:`b` equal to zero :math:`f(x, x_0, A, \Delta, b)` simplifies to a normal gaussian (since :math:`\lim_{b \to 0} \frac{\ln(1+bx)}{b}=x`), - see the definition in :func:`calculate_gaussian`. + see the definition in :func:`SpectralShapeGaussian.calculate`. Parameters ---------- @@ -147,14 +122,17 @@ def calculate_skewed_gaussian(self, axis: np.ndarray) -> np.ndarray: np.ndarray An array representing a skewed Gaussian shape. """ + if np.allclose(self.skewness, 0): + return super().calculate(axis) log_args = 1 + (2 * self.skewness * (axis - self.location) / self.width) - result = np.zeros(log_args.shape) + shape = np.zeros(log_args.shape) valid_arg_mask = np.where(log_args > 0) - result[valid_arg_mask] = self.amplitude * np.exp( + shape[valid_arg_mask] = np.exp( -np.log(2) * np.square(np.log(log_args[valid_arg_mask]) / self.skewness) ) - - return result + if self.amplitude is not None: + shape *= self.amplitude + return shape @model_item(properties={}, has_type=True) @@ -201,6 +179,7 @@ def calculate(self, axis: np.ndarray) -> np.ndarray: @model_item_typed( types={ + "gaussian": SpectralShapeGaussian, "skewed-gaussian": SpectralShapeSkewedGaussian, "one": SpectralShapeOne, "zero": SpectralShapeZero, diff --git a/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py b/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py index b25ec7151..ab2f70900 100644 --- a/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py +++ b/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py @@ -14,7 +14,10 @@ @megacomplex( dimension="spectral", - properties={"energy_spectrum": {"type": bool, "default": False}}, + dataset_properties={ + "spectral_axis_inverted": {"type": bool, "default": False}, + "spectral_axis_scale": {"type": float, "default": 1}, + }, model_items={ "shape": Dict[str, SpectralShape], }, @@ -35,8 +38,10 @@ def calculate_matrix( compartments.append(compartment) model_axis = dataset_model.get_model_axis() - if self.energy_spectrum: - model_axis = 1e7 / model_axis + if dataset_model.spectral_axis_inverted: + model_axis = dataset_model.spectral_axis_scale / model_axis + elif dataset_model.spectral_axis_scale != 1: + model_axis = model_axis * dataset_model.spectral_axis_scale dim1 = 
model_axis.size dim2 = len(self.shape) diff --git a/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py index 9e163292d..ce559a31d 100644 --- a/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py +++ b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py @@ -23,7 +23,7 @@ def from_dict(cls, model_dict): ) -class OneCompartmentModel: +class OneCompartmentModelInvertedAxis: decay_model = DecayModel.from_dict( { "initial_concentration": { @@ -59,7 +59,7 @@ class OneCompartmentModel: }, "shape": { "sh1": { - "type": "skewed-gaussian", + "type": "gaussian", "amplitude": "1", "location": "2", "width": "3", @@ -68,12 +68,76 @@ class OneCompartmentModel: "dataset": { "dataset1": { "megacomplex": ["mc1"], + "spectral_axis_scale": 1e7, + "spectral_axis_inverted": True, + }, + }, + } + ) + + spectral_parameters = ParameterGroup.from_list([7, 1e7 / 10000, 800, -1]) + + time = np.arange(-10, 50, 1.5) + spectral = np.arange(5000, 15000, 20) + axis = {"time": time, "spectral": spectral} + + decay_dataset_model = decay_model.dataset["dataset1"].fill(decay_model, decay_parameters) + decay_dataset_model.overwrite_global_dimension("spectral") + decay_dataset_model.set_coordinates(axis) + matrix = calculate_matrix(decay_dataset_model, {}) + decay_compartments = matrix.clp_labels + clp = xr.DataArray(matrix.matrix, coords=[("time", time), ("clp_label", decay_compartments)]) + + +class OneCompartmentModelNegativeSkew: + decay_model = DecayModel.from_dict( + { + "initial_concentration": { + "j1": {"compartments": ["s1"], "parameters": ["2"]}, + }, + "megacomplex": { + "mc1": {"k_matrix": ["k1"]}, + }, + "k_matrix": { + "k1": { + "matrix": { + ("s1", "s1"): "1", + } + } + }, + "dataset": { + "dataset1": { + "initial_concentration": "j1", + "megacomplex": ["mc1"], }, }, } ) - spectral_parameters = ParameterGroup.from_list([7, 20000, 800]) + decay_parameters = ParameterGroup.from_list( + [101e-4, [1, {"vary": False, "non-negative": False}]] + ) + + spectral_model = SpectralModel.from_dict( + { + "megacomplex": { + "mc1": {"shape": {"s1": "sh1"}}, + }, + "shape": { + "sh1": { + "type": "skewed-gaussian", + "location": "1", + "width": "2", + "skewness": "3", + } + }, + "dataset": { + "dataset1": {"megacomplex": ["mc1"], "spectral_axis_scale": 2}, + }, + } + ) + + spectral_parameters = ParameterGroup.from_list([1000, 80, -1]) time = np.arange(-10, 50, 1.5) spectral = np.arange(400, 600, 5) @@ -87,6 +151,14 @@ class OneCompartmentModel: clp = xr.DataArray(matrix.matrix, coords=[("time", time), ("clp_label", decay_compartments)]) +class OneCompartmentModelPositivSkew(OneCompartmentModelNegativeSkew): + spectral_parameters = ParameterGroup.from_list([7, 20000, 800, 1]) + + +class OneCompartmentModelZeroSkew(OneCompartmentModelNegativeSkew): + spectral_parameters = ParameterGroup.from_list([7, 20000, 800, 0]) + + class ThreeCompartmentModel: decay_model = DecayModel.from_dict( { @@ -131,23 +203,22 @@ class ThreeCompartmentModel: }, "shape": { "sh1": { - "type": "skewed-gaussian", + "type": "gaussian", "amplitude": "1", "location": "2", "width": "3", }, "sh2": { - "type": "skewed-gaussian", + "type": "gaussian", "amplitude": "4", "location": "5", "width": "6", }, "sh3": { - "type": "skewed-gaussian", + "type": "gaussian", "amplitude": "7", "location": "8", "width": "9", - "skewness": "10", }, }, "dataset": { @@ -161,15 +232,14 @@ class ThreeCompartmentModel: spectral_parameters = ParameterGroup.from_list( [ 
7, - 20000, - 800, + 450, + 80, 20, - 22000, - 500, + 550, + 50, + 10, + 580, 10, - 18000, - 650, - 0.1, ] ) @@ -188,26 +258,28 @@ class ThreeCompartmentModel: @pytest.mark.parametrize( "suite", [ - OneCompartmentModel, + OneCompartmentModelNegativeSkew, + OneCompartmentModelPositivSkew, + OneCompartmentModelZeroSkew, ThreeCompartmentModel, ], ) def test_spectral_model(suite): model = suite.spectral_model - print(model.validate()) + print(model.validate()) # noqa assert model.valid() wanted_parameters = suite.spectral_parameters - print(model.validate(wanted_parameters)) - print(wanted_parameters) + print(model.validate(wanted_parameters)) # noqa + print(wanted_parameters) # noqa assert model.valid(wanted_parameters) initial_parameters = suite.spectral_parameters - print(model.validate(initial_parameters)) + print(model.validate(initial_parameters)) # noqa assert model.valid(initial_parameters) - print(model.markdown(initial_parameters)) + print(model.markdown(initial_parameters)) # noqa dataset = simulate(model, "dataset1", wanted_parameters, suite.axis, suite.clp) @@ -222,7 +294,7 @@ def test_spectral_model(suite): maximum_number_function_evaluations=20, ) result = optimize(scheme) - print(result.optimized_parameters) + print(result.optimized_parameters) # noqa for label, param in result.optimized_parameters.all(): assert np.allclose(param.value, wanted_parameters.get(label).value, rtol=1e-1) diff --git a/glotaran/examples/sequential.py b/glotaran/examples/sequential.py index 328102d5b..d0f0f635e 100644 --- a/glotaran/examples/sequential.py +++ b/glotaran/examples/sequential.py @@ -39,19 +39,19 @@ }, "shape": { "sh1": { - "type": "skewed-gaussian", + "type": "gaussian", "amplitude": "shapes.amps.1", "location": "shapes.locs.1", "width": "shapes.width.1", }, "sh2": { - "type": "skewed-gaussian", + "type": "gaussian", "amplitude": "shapes.amps.2", "location": "shapes.locs.2", "width": "shapes.width.2", }, "sh3": { - "type": "skewed-gaussian", + "type": "gaussian", "amplitude": "shapes.amps.3", "location": "shapes.locs.3", "width": "shapes.width.3", diff --git a/glotaran/examples/test/test_example.py b/glotaran/examples/test/test_example.py new file mode 100644 index 000000000..7bd6e74a4 --- /dev/null +++ b/glotaran/examples/test/test_example.py @@ -0,0 +1,7 @@ +import xarray as xr + +from glotaran.examples.sequential import dataset + + +def test_dataset(): + assert isinstance(dataset, xr.Dataset) diff --git a/glotaran/model/item.py b/glotaran/model/item.py index 82d1c9519..21dd6188f 100644 --- a/glotaran/model/item.py +++ b/glotaran/model/item.py @@ -109,9 +109,6 @@ def decorator(cls): from_dict = _create_from_dict_func(cls) setattr(cls, "from_dict", from_dict) - from_list = _create_from_list_func(cls) - setattr(cls, "from_list", from_list) - validate = _create_validation_func(cls) setattr(cls, "validate", validate) @@ -233,7 +230,13 @@ def from_dict(ncls, values: dict) -> cls: for name in ncls._glotaran_properties: if name in values: - setattr(item, name, values[name]) + value = values[name] + prop = getattr(item.__class__, name) + if prop.property_type == float: + value = float(value) + elif prop.property_type == int: + value = int(value) + setattr(item, name, value) elif not getattr(ncls, name).allow_none and getattr(item, name) is None: raise ValueError(f"Missing Property '{name}' For Item '{ncls.__name__}'") @@ -242,32 +245,6 @@ def from_dict(ncls, values: dict) -> cls: return from_dict -def _create_from_list_func(cls): - @classmethod - @wrap_func_as_method(cls) - def from_list(ncls, 
values: list) -> cls: - f"""Creates an instance of {cls.__name__} from a list of values. Intended only for internal use. - - Parameters - ---------- - values : - A list of values. - """ - item = ncls() - if len(values) != len(ncls._glotaran_properties): - raise ValueError( - f"To few or much parameters for '{ncls.__name__}'" - f"\nGot: {values}\nWant: {ncls._glotaran_properties}" - ) - - for i, name in enumerate(ncls._glotaran_properties): - setattr(item, name, values[i]) - - return item - - return from_list - - def _create_validation_func(cls): @wrap_func_as_method(cls) def validate(self, model: Model, parameters: ParameterGroup | None = None) -> list[str]: diff --git a/glotaran/model/property.py b/glotaran/model/property.py index 92518badc..41c6d4c0f 100644 --- a/glotaran/model/property.py +++ b/glotaran/model/property.py @@ -13,10 +13,10 @@ def __init__(self, cls, name, prop_type, doc, default, allow_none): self._allow_none = allow_none self._determine_if_parameter(prop_type) - set_type = prop_type if not self._is_parameter else typing.Union[str, prop_type] + self._type = prop_type if not self._is_parameter else typing.Union[str, prop_type] @wrap_func_as_method(cls, name=name) - def setter(that_self, value: set_type): + def setter(that_self, value: self._type): if value is None and not self._allow_none: raise Exception( f"Property '{name}' of '{cls.__name__}' is not allowed to set to None." @@ -44,6 +44,10 @@ def getter(that_self) -> prop_type: def allow_none(self) -> bool: return self._allow_none + @property + def property_type(self) -> typing.Type: + return self._type + def validate(self, value, model, parameters=None) -> typing.List[str]: if value is None and self.allow_none: diff --git a/glotaran/parameter/parameter.py b/glotaran/parameter/parameter.py index 3a0ea0f2b..7027da627 100644 --- a/glotaran/parameter/parameter.py +++ b/glotaran/parameter/parameter.py @@ -9,6 +9,8 @@ import numpy as np from numpy.typing._array_like import _SupportsArray +from glotaran.utils.sanitize import sanitize_parameter_list + if TYPE_CHECKING: from typing import Any @@ -113,7 +115,7 @@ def from_list_or_value( param.value = value else: - values = _sanatize_parameter_list(value) + values = sanitize_parameter_list(value) param.label = _retrieve_from_list_by_type(values, str, label) param.value = float(_retrieve_from_list_by_type(values, (int, float), 0)) options = _retrieve_from_list_by_type(values, dict, None) @@ -485,18 +487,6 @@ def _log_value(value: float): return np.log(value) -# A reexp for ONLY matching scientific -_match_scientific = re.compile(r"[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)") - - -def _sanatize_parameter_list(li: list) -> list: - for i, value in enumerate(li): - if isinstance(value, str) and _match_scientific.match(value): - li[i] = float(value) - - return li - - def _retrieve_from_list_by_type(li: list, t: type | tuple[type, ...], default: Any): tmp = list(filter(lambda x: isinstance(x, t), li)) if not tmp: diff --git a/glotaran/test/test_spectral_decay.py b/glotaran/test/test_spectral_decay.py index d0f86cecf..a23ce515d 100644 --- a/glotaran/test/test_spectral_decay.py +++ b/glotaran/test/test_spectral_decay.py @@ -25,17 +25,17 @@ s3: sh3 shape: sh1: - type: skewed-gaussian + type: gaussian amplitude: shapes.amps.1 location: shapes.locs.1 width: shapes.width.1 sh2: - type: skewed-gaussian + type: gaussian amplitude: shapes.amps.2 location: shapes.locs.2 width: shapes.width.2 sh3: - type: skewed-gaussian + type: gaussian amplitude: shapes.amps.3 location: shapes.locs.3 width: 
shapes.width.3 @@ -75,17 +75,17 @@ width: [irf.width] shape: sh1: - type: skewed-gaussian + type: gaussian amplitude: shapes.amps.1 location: shapes.locs.1 width: shapes.width.1 sh2: - type: skewed-gaussian + type: gaussian amplitude: shapes.amps.2 location: shapes.locs.2 width: shapes.width.2 sh3: - type: skewed-gaussian + type: gaussian amplitude: shapes.amps.3 location: shapes.locs.3 width: shapes.width.3 diff --git a/glotaran/test/test_spectral_decay_full_model.py b/glotaran/test/test_spectral_decay_full_model.py index 8f2e7e532..28d441db9 100644 --- a/glotaran/test/test_spectral_decay_full_model.py +++ b/glotaran/test/test_spectral_decay_full_model.py @@ -42,17 +42,17 @@ width: [irf.width] shape: sh1: - type: skewed-gaussian + type: gaussian amplitude: shapes.amps.1 location: shapes.locs.1 width: shapes.width.1 sh2: - type: skewed-gaussian + type: gaussian amplitude: shapes.amps.2 location: shapes.locs.2 width: shapes.width.2 sh3: - type: skewed-gaussian + type: gaussian amplitude: shapes.amps.3 location: shapes.locs.3 width: shapes.width.3 diff --git a/glotaran/test/test_spectral_penalties.py b/glotaran/test/test_spectral_penalties.py index a00c9b853..5b055aedc 100644 --- a/glotaran/test/test_spectral_penalties.py +++ b/glotaran/test/test_spectral_penalties.py @@ -127,13 +127,13 @@ def test_equal_area_penalties(debug=False): shape = { "shape": { "sh1": { - "type": "skewed-gaussian", + "type": "gaussian", "amplitude": "shapes.amps.1", "location": "shapes.locs.1", "width": "shapes.width.1", }, "sh2": { - "type": "skewed-gaussian", + "type": "gaussian", "amplitude": "shapes.amps.2", "location": "shapes.locs.2", "width": "shapes.width.2", diff --git a/glotaran/utils/regex.py b/glotaran/utils/regex.py new file mode 100644 index 000000000..0f701bad9 --- /dev/null +++ b/glotaran/utils/regex.py @@ -0,0 +1,16 @@ +"""Glotaran module with regular expression patterns and functions.""" +import re + + +class RegexPattern: + """An 'Enum' of (compiled) regular expression patterns (rp).""" + + # tuple = re.compile(r"(\(.*?,.*?\))") + elements_in_string_of_list: re.Pattern = re.compile(r"(\(.+?\)|[-+.\d]+)") + group: re.Pattern = re.compile(r"(\(.+?\))") + list_with_tuples: re.Pattern = re.compile(r"(\[.+\(.+\).+\])") + word: re.Pattern = re.compile(r"[\w]+") + number_scientific: re.Pattern = re.compile(r"[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)") + number: re.Pattern = re.compile(r"[\d.+-]+") + tuple_word: re.Pattern = re.compile(r"(\([.\s\w\d]+?[,.\s\w\d]*?\))") + tuple_number: re.Pattern = re.compile(r"(\([\s\d.+-]+?[,\s\d.+-]*?\))") diff --git a/glotaran/builtin/io/yml/sanatize.py b/glotaran/utils/sanitize.py similarity index 65% rename from glotaran/builtin/io/yml/sanatize.py rename to glotaran/utils/sanitize.py index fe3084656..d26a97d3e 100644 --- a/glotaran/builtin/io/yml/sanatize.py +++ b/glotaran/utils/sanitize.py @@ -1,22 +1,14 @@ -import re -from typing import List -from typing import Tuple -from typing import Union +"""Glotaran module with utilities for sanitation of parsed content.""" +from __future__ import annotations -from glotaran.deprecation import warn_deprecated +from typing import Any -# tuple_pattern = re.compile(r"(\(.*?,.*?\))") -tuple_number_pattern = re.compile(r"(\([\s\d.+-]+?[,\s\d.+-]*?\))") -number_pattern = re.compile(r"[\d.+-]+") -tuple_name_pattern = re.compile(r"(\([.\s\w\d]+?[,.\s\w\d]*?\))") -name_pattern = re.compile(r"[\w]+") -group_pattern = re.compile(r"(\(.+?\))") -match_list_with_tuples = re.compile(r"(\[.+\(.+\).+\])") -match_elements_in_string_of_list = 
re.compile(r"(\(.+?\)|[-+.\d]+)") +from glotaran.deprecation import warn_deprecated +from glotaran.utils.regex import RegexPattern as rp -def sanitize_list_with_broken_tuples(mangled_list: List[Union[str, float]]) -> List[str]: - """Sanitize a list with 'broken' tuples +def sanitize_list_with_broken_tuples(mangled_list: list[str | float]) -> list[str]: + """Sanitize a list with 'broken' tuples. A list of broken tuples as returned by yaml when parsing tuples. e.g parsing the list of tuples [(3,100), (4,200)] results in @@ -34,13 +26,12 @@ def sanitize_list_with_broken_tuples(mangled_list: List[Union[str, float]]) -> L A list containing the restores tuples (in string form) which can be converted back to numbered tuples using `list_string_to_tuple` """ - sanitized_string = str(mangled_list).replace("'", "") - return list(match_elements_in_string_of_list.findall(sanitized_string)) + return list(rp.elements_in_string_of_list.findall(sanitized_string)) def sanitize_dict_keys(d: dict) -> dict: - """Sanitize the stringified tuple dict keys in a yaml parsed dict + """Sanitize the stringified tuple dict keys in a yaml parsed dict. Keys representing a tuple, e.g. '(s1, s2)' are converted to a tuple of strings e.g. ('s1', 's2') @@ -59,8 +50,8 @@ def sanitize_dict_keys(d: dict) -> dict: return {} d_new = {} for k, v in d.items() if isinstance(d, dict) else enumerate(d): - if isinstance(d, dict) and isinstance(k, str) and tuple_name_pattern.match(k): - k_new = tuple(map(str, name_pattern.findall(k))) + if isinstance(d, dict) and isinstance(k, str) and rp.tuple_word.match(k): + k_new = tuple(map(str, rp.word.findall(k))) d_new.update({k_new: v}) elif isinstance(d, (dict, list)): new_v = sanitize_dict_keys(v) @@ -69,18 +60,38 @@ def sanitize_dict_keys(d: dict) -> dict: return d_new -def sanitize_dict_values(d: dict): - """Sanitizes a dict with broken tuples inside modifying it in-place +def sanity_scientific_notation_conversion(d: dict[str, Any] | list[Any]): + """Convert scientific notation string values to floats. + + Parameters + ---------- + d : dict[str, Any] | list[Any] + Iterable which should be checked for scientific notation values. + """ + if not isinstance(d, (dict, list)): + return + for k, v in d.items() if isinstance(d, dict) else enumerate(d): # type: ignore[attr-defined] + if isinstance(v, (list, dict)): + sanity_scientific_notation_conversion(v) + if isinstance(v, str): + d[k] = convert_scientific_to_float(v) + + +def sanitize_dict_values(d: dict[str, Any] | list[Any]): + """Sanitizes a dict with broken tuples inside modifying it in-place. + Broken tuples are tuples that are turned into strings by the yaml parser. This functions calls `sanitize_list_with_broken_tuples` to glue the broken strings together and then calls list_to_tuple to turn the list with tuple strings back to number tuples. - Args: - d (dict): A (complex) dict containing (possibly nested) values of broken tuple strings + Parameters + ---------- + d : dict + A (complex) dict containing (possibly nested) values of broken tuple strings. """ if not isinstance(d, (dict, list)): return - for k, v in d.items() if isinstance(d, dict) else enumerate(d): + for k, v in d.items() if isinstance(d, dict) else enumerate(d): # type: ignore[attr-defined] if isinstance(v, list): leaf = all(isinstance(el, (str, tuple, float)) for el in v) if leaf: @@ -96,8 +107,8 @@ def sanitize_dict_values(d: dict): def string_to_tuple( tuple_str: str, from_list=False -) -> Union[Tuple[float], Tuple[str], float, str]: - """[summary] +) -> tuple[float, ...] 
| tuple[str, ...] | float | str: + """Convert a string to a tuple if it matches a tuple pattern. Parameters ---------- @@ -111,22 +122,23 @@ def string_to_tuple( Returns ------- - Union[Tuple[float], Tuple[str], float, str] + tuple[float], tuple[str], float, str Returns the tuple intended by the string """ - - if tuple_number_pattern.match(tuple_str): - return tuple(map(float, number_pattern.findall(tuple_str))) - elif tuple_name_pattern.match(tuple_str): - return tuple(map(str, name_pattern.findall(tuple_str))) - elif from_list and number_pattern.match(tuple_str): + if rp.tuple_number.match(tuple_str): + return tuple(map(float, rp.number.findall(tuple_str))) + elif rp.tuple_word.match(tuple_str): + return tuple(map(str, rp.word.findall(tuple_str))) + elif from_list and rp.number.match(tuple_str): return float(tuple_str) else: return tuple_str -def list_string_to_tuple(a_list: List[str]) -> List[Union[float, str]]: - """Converts a list of strings (representing tuples) to a list of tuples +def list_string_to_tuple( + a_list: list[str], +) -> list[tuple[float, ...] | tuple[str, ...] | float | str]: + """Convert a list of strings (representing tuples) to a list of tuples. Parameters ---------- @@ -138,18 +150,20 @@ def list_string_to_tuple(a_list: List[str]) -> List[Union[float, str]]: List[Union[float, str]] A list of the (numbered) tuples represted by the incoming a_list """ - for i, v in enumerate(a_list): - a_list[i] = string_to_tuple(v, from_list=True) - return a_list + return [string_to_tuple(v, from_list=True) for v in a_list] def sanitize_yaml(d: dict, do_keys: bool = True, do_values: bool = False) -> dict: - """Sanitize a yaml-returned dict for key or (list) values containing tuples + """Sanitize a yaml-returned dict for key or (list) values containing tuples. Parameters ---------- d : dict a dict resulting from parsing a pyglotaran model spec yml file + do_keys : bool + toggle sanitization of dict keys, by default True + do_values : bool + toggle sanitization of dict values, by default False Returns ------- @@ -161,10 +175,57 @@ def sanitize_yaml(d: dict, do_keys: bool = True, do_values: bool = False) -> dic if do_values: # this is only needed to allow for tuple parsing in specification sanitize_dict_values(d) + sanity_scientific_notation_conversion(d) return d +def convert_scientific_to_float(value: str) -> float | str: + """Convert value to float if it matches scientific notation string. + + Parameters + ---------- + value : str + value to convert from string to float if it matches scientific notation + + Returns + ------- + float | string + return float if value was scientific notation string, else turn original value + """ + if rp.number_scientific.match(value): + return float(value) + else: + return value + + +def sanitize_parameter_list(parameter_list: list[str | float]) -> list[str | float]: + """Replace in a list strings matching scientific notation with floats. + + Parameters + ---------- + parameter_list : list + A list of parameters where some elements may be strings like 1E7 + + Returns + ------- + list + A list where strings matching a scientific number have been converted to float + """ + for i, value in enumerate(parameter_list): + if isinstance(value, str): + parameter_list[i] = convert_scientific_to_float(value) + + return parameter_list + + def check_deprecations(spec: dict): + """Check deprecations in a `spec` dict. 
+ + Parameters + ---------- + spec : dict + A specification dictionary + """ if "type" in spec: if spec["type"] == "kinetic-spectrum": warn_deprecated( diff --git a/glotaran/builtin/io/yml/test/test_util.py b/glotaran/utils/test/test_sanitize.py similarity index 93% rename from glotaran/builtin/io/yml/test/test_util.py rename to glotaran/utils/test/test_sanitize.py index 812b00b93..3c7e6b948 100644 --- a/glotaran/builtin/io/yml/test/test_util.py +++ b/glotaran/utils/test/test_sanitize.py @@ -5,12 +5,12 @@ import pytest -from glotaran.builtin.io.yml.sanatize import sanitize_list_with_broken_tuples +from glotaran.utils.sanitize import sanitize_list_with_broken_tuples class MangledListTestData(NamedTuple): input: list[Any] - input_sanitized: list[str] + input_sanitized: list[str] | str output: list[str] From c2bb87b6505a5a85e25759fc3b895c89008c8fec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn=20Wei=C3=9Fenborn?= Date: Fri, 13 Aug 2021 16:27:42 +0200 Subject: [PATCH 09/29] Fix/cli0.5 (#765) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Various fixes and improvements to the glotaran command line interface. * Changed CLI save plugin to folder * Added outputformat option to CLI * Added basic test for CLI * Rename CLI entrypoint to main and add more CLI tests * 👌 CLI use same default for non_negative_least_squares as scheme * 👌 CLI dedent pluginlist output * 🩹 CLI fixed result outformat accepting none supported formats Co-authored-by: Joris Snellenburg Co-authored-by: s-weigand --- glotaran/cli/__init__.py | 1 + glotaran/cli/commands/optimize.py | 27 +++++++------- glotaran/cli/commands/pluginlist.py | 16 +++++---- glotaran/cli/commands/test/test_util.py | 10 ++++++ glotaran/cli/commands/util.py | 47 ++++++++++++++++++++++--- glotaran/cli/main.py | 29 +++++++-------- glotaran/cli/test/test_cli.py | 43 ++++++++++++++++++++++ glotaran/project/scheme.py | 2 +- setup.cfg | 2 +- 9 files changed, 138 insertions(+), 39 deletions(-) create mode 100644 glotaran/cli/commands/test/test_util.py create mode 100644 glotaran/cli/test/test_cli.py diff --git a/glotaran/cli/__init__.py b/glotaran/cli/__init__.py index e69de29bb..125ef8fd8 100644 --- a/glotaran/cli/__init__.py +++ b/glotaran/cli/__init__.py @@ -0,0 +1 @@ +from glotaran.cli.main import main diff --git a/glotaran/cli/commands/optimize.py b/glotaran/cli/commands/optimize.py index 643f0d066..6236e0be0 100644 --- a/glotaran/cli/commands/optimize.py +++ b/glotaran/cli/commands/optimize.py @@ -5,8 +5,8 @@ from glotaran.analysis.optimize import optimize from glotaran.cli.commands import util -from glotaran.io import save_result from glotaran.plugin_system.data_io_registration import known_data_formats +from glotaran.plugin_system.project_io_registration import save_result from glotaran.project.scheme import Scheme @@ -32,6 +32,14 @@ help="Path to an output directory.", show_default=True, ) +@click.option( + "--outformat", + "-ofmt", + default="folder", + type=click.Choice(util.project_io_list_supporting_plugins("save_result", ("yml_str"))), + help="The format of the output.", + show_default=True, +) @click.option( "--nfev", "-n", @@ -40,13 +48,14 @@ help="Maximum number of function evaluations.", show_default=True, ) -@click.option("--nnls", is_flag=True, help="Use non-negative least squares.") +@click.option("--nnls", is_flag=True, default=False, help="Use non-negative least squares.") @click.option("--yes", "-y", is_flag=True, help="Don't ask for confirmation.") @util.signature_analysis def optimize_cmd( 
dataformat: str, data: typing.List[str], out: str, + outformat: str, nfev: int, nnls: bool, yes: bool, @@ -62,7 +71,9 @@ def optimize_cmd( if scheme_file is not None: scheme = util.load_scheme_file(scheme_file, verbose=True) if nfev is not None: - scheme.nfev = nfev + scheme.maximum_number_function_evaluations = nfev + + scheme.non_negative_least_squares = nnls else: if model_file is None: click.echo("Error: Neither scheme nor model specified", err=True) @@ -100,14 +111,6 @@ def optimize_cmd( click.echo(f"Saving directory: is '{out if out is not None else 'None'}'") if yes or click.confirm("Do you want to start optimization?", abort=True, default=True): - # try: - # click.echo('Preparing optimization...', nl=False) - # optimizer = gta.analysis.optimizer.Optimizer(scheme) - # click.echo(' Success') - # except Exception as e: - # click.echo(" Error") - # click.echo(e, err=True) - # sys.exit(1) try: click.echo("Optimizing...") result = optimize(scheme) @@ -123,7 +126,7 @@ def optimize_cmd( try: click.echo(f"Saving directory is '{out}'") if yes or click.confirm("Do you want to save the data?", default=True): - save_result(result_path=out, format_name="yml", result=result) + save_result(result_path=out, format_name=outformat, result=result) click.echo("File saving successful.") except Exception as e: click.echo(f"An error occurred during saving: \n\n{e}", err=True) diff --git a/glotaran/cli/commands/pluginlist.py b/glotaran/cli/commands/pluginlist.py index 58a9e4607..e970de91b 100644 --- a/glotaran/cli/commands/pluginlist.py +++ b/glotaran/cli/commands/pluginlist.py @@ -1,21 +1,25 @@ +from textwrap import dedent + import click -from glotaran.model import known_model_names from glotaran.plugin_system.data_io_registration import known_data_formats +from glotaran.plugin_system.megacomplex_registration import known_megacomplex_names from glotaran.plugin_system.project_io_registration import known_project_formats def plugin_list_cmd(): """Prints a list of installed plugins.""" - output = """ - Installed Glotaran Plugins: + output = dedent( + """ + Installed Glotaran Plugins: - Models: - """ + Megacomplex Models: + """ + ) output += "\n" - for name in known_model_names(): + for name in known_megacomplex_names(): output += f" * {name}\n" output += "\nData file Formats\n\n" diff --git a/glotaran/cli/commands/test/test_util.py b/glotaran/cli/commands/test/test_util.py new file mode 100644 index 000000000..6a9cce9da --- /dev/null +++ b/glotaran/cli/commands/test/test_util.py @@ -0,0 +1,10 @@ +from glotaran.cli.commands.util import project_io_list_supporting_plugins + + +def test_project_io_list_supporting_plugins_save_result(): + """Same as used in ``--outformat`` CLI option.""" + result = project_io_list_supporting_plugins("save_result", ("yml_str")) + + assert "csv" not in result + assert "yml_str" not in result + assert "folder" in result diff --git a/glotaran/cli/commands/util.py b/glotaran/cli/commands/util.py index 27c5c3551..ef83b42ec 100644 --- a/glotaran/cli/commands/util.py +++ b/glotaran/cli/commands/util.py @@ -1,10 +1,20 @@ +from __future__ import annotations + import sys +from typing import Iterable import click from click import echo from click import prompt -import glotaran as gta +from glotaran.io import ProjectIoInterface +from glotaran.io import load_dataset +from glotaran.io import load_model +from glotaran.io import load_parameters +from glotaran.io import load_scheme +from glotaran.plugin_system.base_registry import methods_differ_from_baseclass_table +from 
glotaran.plugin_system.project_io_registration import get_project_io +from glotaran.plugin_system.project_io_registration import known_project_formats def signature_analysis(cmd): @@ -46,23 +56,25 @@ def _load_file(filename, loader, name, verbose): def load_scheme_file(filename, verbose=False): - return _load_file(filename, gta.analysis.scheme.Scheme.from_yaml_file, "scheme", verbose) + return _load_file( + filename, lambda file: load_scheme(file, format_name="yml"), "scheme", verbose + ) def load_model_file(filename, verbose=False): - return _load_file(filename, gta.read_model_from_yaml_file, "model", verbose) + return _load_file(filename, lambda file: load_model(file, format_name="yml"), "model", verbose) def load_parameter_file(filename, fmt=None, verbose=False): def loader(filename): - return gta.parameter.ParameterGroup.from_file(filename, fmt=fmt) + return load_parameters(filename, format_name=fmt) return _load_file(filename, loader, "parameter", verbose) def load_dataset_file(filename, fmt=None, verbose=False): def loader(filename): - return gta.io.read_data_file(filename, fmt=fmt) + return load_dataset(filename, format_name=fmt) return _load_file(filename, loader, "parameter", verbose) @@ -116,6 +128,31 @@ def write_data(data, out): df.to_csv(out) +def project_io_list_supporting_plugins( + method_name: str, block_list: Iterable[str] | None = None +) -> Iterable[str]: + """List all project-io plugin that implement ``method_name``. + + Parameters + ---------- + method_name: str + Name of the method which should be supported. + block_list: Iterable[str] + Iterable of plugin names which should be omitted. + """ + if block_list is None: + block_list = [] + support_table = methods_differ_from_baseclass_table( + method_names=method_name, + plugin_registry_keys=known_project_formats(full_names=False), + get_plugin_function=get_project_io, + base_class=ProjectIoInterface, + ) + support_table = filter(lambda entry: entry[1], support_table) + supporting_list: Iterable[str] = (entry[0].replace("`", "") for entry in support_table) + return list(filter(lambda entry: entry not in block_list, supporting_list)) + + class ValOrRangeOrList(click.ParamType): name = "number or range or list" diff --git a/glotaran/cli/main.py b/glotaran/cli/main.py index 34124ffa2..cf516e727 100644 --- a/glotaran/cli/main.py +++ b/glotaran/cli/main.py @@ -1,6 +1,6 @@ import click -import glotaran as gta +from glotaran import __version__ as VERSION from glotaran.cli.commands.optimize import optimize_cmd from glotaran.cli.commands.pluginlist import plugin_list_cmd from glotaran.cli.commands.print import print_cmd @@ -8,6 +8,8 @@ class Cli(click.Group): + """The glotaran CLI implementation of :class:`click.group`""" + def __init__(self, *args, **kwargs): self.help_priorities = {} super().__init__(*args, **kwargs) @@ -42,32 +44,31 @@ def decorator(f): @click.group(cls=Cli) -@click.version_option(version=gta.__version__) -def glotaran(): +@click.version_option(version=VERSION) +def main(prog_name="glotaran"): + """The glotaran CLI main function.""" pass -glotaran.add_command( - glotaran.command( +main.add_command( + main.command( name="pluginlist", short_help="Prints a list of installed plugins.", help_priority=4 )(plugin_list_cmd) ) -glotaran.add_command( - glotaran.command(name="print", short_help="Prints a model as markdown.", help_priority=3)( +main.add_command( + main.command(name="print", short_help="Prints a model as markdown.", help_priority=3)( print_cmd ) ) -glotaran.add_command( - glotaran.command(name="validate", 
short_help="Validates a model file.", help_priority=2)( +main.add_command( + main.command(name="validate", short_help="Validates a model file.", help_priority=2)( validate_cmd ) ) -glotaran.add_command( - glotaran.command(name="optimize", short_help="Optimizes a model.", help_priority=1)( - optimize_cmd - ) +main.add_command( + main.command(name="optimize", short_help="Optimizes a model.", help_priority=1)(optimize_cmd) ) if __name__ == "__main__": - glotaran() + raise SystemExit(main(prog_name="glotaran")) diff --git a/glotaran/cli/test/test_cli.py b/glotaran/cli/test/test_cli.py new file mode 100644 index 000000000..76922ef05 --- /dev/null +++ b/glotaran/cli/test/test_cli.py @@ -0,0 +1,43 @@ +from pathlib import Path + +from click.testing import CliRunner + +from glotaran.cli import main + + +def test_cli_help(): + """Test the CLI help options.""" + runner = CliRunner() + result = runner.invoke(main) + assert result.exit_code == 0 + help_result = runner.invoke(main, ["--help"], prog_name="glotaran") + assert help_result.exit_code == 0 + assert "Usage: glotaran [OPTIONS] COMMAND [ARGS]..." in help_result.output + + +def test_cli_pluginlist(): + """Test the CLI pluginlist option.""" + runner = CliRunner() + result = runner.invoke(main, ["pluginlist"], prog_name="glotaran") + assert result.exit_code == 0 + assert "Installed Glotaran Plugins" in result.output + + +def test_cli_validate_parameters_file(tmp_path: Path): + """Test the CLI pluginlist option.""" + empty_file = tmp_path.joinpath("empty_file.yml") + empty_file.touch() + runner = CliRunner() + result_ok = runner.invoke( + main, ["validate", "--parameters_file", str(empty_file)], prog_name="glotaran" + ) + assert result_ok.exit_code == 0 + assert "Type 'glotaran validate --help' for more info." 
in result_ok.output + non_existing_file = tmp_path.joinpath("_does_not_exist_.yml") + result_file_not_exist = runner.invoke( + main, ["validate", "--parameters_file", str(non_existing_file)], prog_name="glotaran" + ) + assert result_file_not_exist.exit_code == 2 + assert all( + substring in result_file_not_exist.output for substring in ("Error", "does not exist") + ) diff --git a/glotaran/project/scheme.py b/glotaran/project/scheme.py index b1c73c5f5..3376a5dc5 100644 --- a/glotaran/project/scheme.py +++ b/glotaran/project/scheme.py @@ -37,7 +37,7 @@ class Scheme: group: bool | None = None group_tolerance: float = 0.0 non_negative_least_squares: bool = False - maximum_number_function_evaluations: int = None + maximum_number_function_evaluations: int | None = None add_svd: bool = True ftol: float = 1e-8 gtol: float = 1e-8 diff --git a/setup.cfg b/setup.cfg index 3175b7a64..9b3a44fbc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,7 +53,7 @@ zip_safe = True [options.entry_points] console_scripts = - glotaran=glotaran.cli.main:glotaran + glotaran=glotaran.cli.main:main glotaran.plugins.data_io = ascii = glotaran.builtin.io.ascii.wavelength_time_explicit_file sdt = glotaran.builtin.io.sdt.sdt_file_reader From d2b297b534dd7092eb3c00d69132da7c0e4971ee Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Fri, 13 Aug 2021 20:48:46 +0200 Subject: [PATCH 10/29] =?UTF-8?q?=F0=9F=A9=B9Fix=20check=5Fdeprecations=20?= =?UTF-8?q?not=20showing=20deprecation=20warnings=20(#775)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ Added 'deprecate_dict_entry' to deprecate dict keys and/or values * 🧹📚 Changed Warns OverDueDeprecation to Raises and added missing Raises OverDueDeprecation to deprecate_module_attribute and deprecate_submodule * 🩹🧪 Reimplemented check_deprecations and with deprecate_dict_entry and renamed it to 'model_spec_deprecations' This change also ensures that the deprecation warning is thrown when users calls 'load_model' so python will show the warning * ♻️🧪 Changed deprecation_warning_on_call_test_helper to test for file emitting the warning and return the WarningsRecorder * 📚 Added docs for deprecate_dict_entry * 🧹 Removed unused record variables * 🧪 Added test for deprecated 'spectral_constraints' * 👌 Changed DeprecationWarning to GlotaranApiDeprecationWarning (GlotaranApiDeprecationWarning is a subclass of UserWarning so it won't be filtered automatically by python) * 📚 Updated docs to reflect the usage of GlotaranApiDeprecationWarning --- CONTRIBUTING.rst | 16 +- glotaran/builtin/io/yml/yml.py | 4 +- glotaran/deprecation/__init__.py | 1 + glotaran/deprecation/deprecation_utils.py | 172 +++++++++++++++++- .../deprecation/modules/builtin_io_yml.py | 87 +++++++++ glotaran/deprecation/modules/test/__init__.py | 26 ++- .../modules/test/test_builtin_io_yml.py | 87 +++++++++ .../modules/test/test_changed_imports.py | 3 +- .../modules/test/test_glotaran_root.py | 17 +- .../modules/test/test_project_result.py | 2 +- .../modules/test/test_project_sheme.py | 2 +- .../test/test_deprecation_utils.py | 163 ++++++++++++++++- glotaran/utils/sanitize.py | 81 --------- tox.ini | 2 +- 14 files changed, 546 insertions(+), 117 deletions(-) create mode 100644 glotaran/deprecation/modules/builtin_io_yml.py create mode 100644 glotaran/deprecation/modules/test/test_builtin_io_yml.py diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 391d52ded..991396fe9 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -175,12 +175,15 @@ To make deprecations as robust as 
possible and give users all needed information to adjust their code, we provide helper functions inside the module :mod:`glotaran.deprecation`. +.. currentmodule:: glotaran.deprecation.deprecation_utils + The functions you most likely want to use are * :func:`deprecate` for functions, methods and classes * :func:`warn_deprecated` for call arguments * :func:`deprecate_module_attribute` for module attributes * :func:`deprecate_submodule` for modules +* :func:`deprecate_dict_entry` for dict entries Those functions not only make it easier to deprecate something, but they also check that @@ -193,7 +196,7 @@ provides the test helper functions ``deprecation_warning_on_call_test_helper`` a Since the tests for deprecation are mainly for maintainability and not to test the functionality (those tests should be in the appropriate place) ``deprecation_warning_on_call_test_helper`` will by default just test that a -``DeprecationWarning`` was raised and ignore all raise ``Exception`` s. +``GlotaranApiDeprecationWarning`` was raised and ignore all raise ``Exception`` s. An exception to this rule is when adding back removed functionality (which shouldn't happen in the first place but might), which should be implemented in a file under ``glotaran/deprecation/modules`` and filenames should be like the @@ -300,6 +303,17 @@ as an attribute to the parent package. to_be_removed_in_version="0.6.0", ) +Deprecating dict entries +~~~~~~~~~~~~~~~~~~~~~~~~ +The possible dict deprecation actions are: + +- Swapping of keys ``{"foo": 1} -> {"bar": 1}`` (done via ``swap_keys=("foo", "bar")``) +- Replacing of matching values ``{"foo": 1} -> {"foo": 2}`` (done via ``replace_rules=({"foo": 1}, {"foo": 2})``) +- Replacing of matching values and swapping of keys ``{"foo": 1} -> {"bar": 2}`` (done via ``replace_rules=({"foo": 1}, {"bar": 2})``) + +For full examples have a look at the examples from the docstring (:func:`deprecate_dict_entry`). 
+ + Deploying --------- diff --git a/glotaran/builtin/io/yml/yml.py b/glotaran/builtin/io/yml/yml.py index f0b0a42df..e45ad6cca 100644 --- a/glotaran/builtin/io/yml/yml.py +++ b/glotaran/builtin/io/yml/yml.py @@ -7,6 +7,7 @@ import yaml +from glotaran.deprecation.modules.builtin_io_yml import model_spec_deprecations from glotaran.io import ProjectIoInterface from glotaran.io import load_dataset from glotaran.io import load_model @@ -19,7 +20,6 @@ from glotaran.parameter import ParameterGroup from glotaran.project import SavingOptions from glotaran.project import Scheme -from glotaran.utils.sanitize import check_deprecations from glotaran.utils.sanitize import sanitize_yaml if TYPE_CHECKING: @@ -49,7 +49,7 @@ def load_model(self, file_name: str) -> Model: with open(file_name) as f: spec = yaml.safe_load(f) - check_deprecations(spec) + model_spec_deprecations(spec) spec = sanitize_yaml(spec) diff --git a/glotaran/deprecation/__init__.py b/glotaran/deprecation/__init__.py index fd4578f55..14edceab3 100644 --- a/glotaran/deprecation/__init__.py +++ b/glotaran/deprecation/__init__.py @@ -1,5 +1,6 @@ """Deprecation helpers and place to put deprecated implementations till removing.""" from glotaran.deprecation.deprecation_utils import deprecate +from glotaran.deprecation.deprecation_utils import deprecate_dict_entry from glotaran.deprecation.deprecation_utils import deprecate_module_attribute from glotaran.deprecation.deprecation_utils import deprecate_submodule from glotaran.deprecation.deprecation_utils import warn_deprecated diff --git a/glotaran/deprecation/deprecation_utils.py b/glotaran/deprecation/deprecation_utils.py index 8f9c18db1..d6547116b 100644 --- a/glotaran/deprecation/deprecation_utils.py +++ b/glotaran/deprecation/deprecation_utils.py @@ -11,10 +11,15 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Hashable +from typing import Mapping +from typing import MutableMapping from typing import TypeVar from typing import cast from warnings import warn +import numpy as np + DecoratedCallable = TypeVar( "DecoratedCallable", bound=Callable[..., Any] ) # decorated function or class @@ -32,6 +37,20 @@ class OverDueDeprecation(Exception): warn_deprecated deprecate_module_attribute deprecate_submodule + deprecate_dict_entry + """ + + +class GlotaranApiDeprecationWarning(UserWarning): + """Warning to give users about API changes. + + See Also + -------- + deprecate + warn_deprecated + deprecate_module_attribute + deprecate_submodule + deprecate_dict_entry """ @@ -166,8 +185,8 @@ def warn_deprecated( will import ``package.module.class`` and check if ``class`` has an attribute ``mapping``. - Warns - ----- + Raises + ------ OverDueDeprecation If the current version is greater or equal to ``end_of_life_version``. @@ -212,7 +231,7 @@ def read_parameters_from_yaml_file(model_path: str): selected_indices = importable_indices[: len(selected_qual_names)] check_qualnames_in_tests(qual_names=selected_qual_names, importable_indices=selected_indices) warn( - DeprecationWarning( + GlotaranApiDeprecationWarning( f"Usage of {deprecated_qual_name_usage!r} was deprecated, " f"use {new_qual_name_usage!r} instead.\n" f"This usage will be an error in version: {to_be_removed_in_version!r}." @@ -265,8 +284,8 @@ def deprecate( DecoratedCallable Original function or class throwing a Deprecation warning when used. - Warns - ----- + Raises + ------ OverDueDeprecation If the current version is greater or equal to ``end_of_life_version``. 
@@ -334,6 +353,134 @@ def outer_wrapper(deprecated_object: DecoratedCallable) -> DecoratedCallable: return cast(Callable[[DecoratedCallable], DecoratedCallable], outer_wrapper) +def deprecate_dict_entry( + *, + dict_to_check: MutableMapping[Hashable, Any], + deprecated_usage: str, + new_usage: str, + to_be_removed_in_version: str, + swap_keys: tuple[Hashable, Hashable] | None = None, + replace_rules: tuple[Mapping[Hashable, Any], Mapping[Hashable, Any]] | None = None, + stacklevel: int = 3, +) -> None: + """Replace dict entry inplace and warn about usage change, if present in the dict. + + Parameters + ---------- + dict_to_check : MutableMapping[Hashable, Any] + Dict which should be checked. + deprecated_usage : str + Old usage to inform user (only used in warning). + new_usage : str + New usage to inform user (only used in warning). + to_be_removed_in_version : str + Version the support for this usage will be removed. + swap_keys : tuple[Hashable, Hashable] + (old_key, new_key), + ``dict_to_check[new_key]`` will be assigned the value ``dict_to_check[old_key]`` + and ``old_key`` will be removed from the dict. + by default None + replace_rules : Mapping[Hashable, tuple[Any, Any]] + ({old_key: old_value}, {new_key: new_value}), + If ``dict_to_check[old_key]`` has the value ``old_value``, + ``dict_to_check[new_key]`` it will be set to ``new_value``. + ``old_key`` will be removed from the dict if ``old_key`` and ``new_key`` aren't equal. + by default None + stacklevel : int + Stack at which the warning should be shown as raise. , by default 3 + + + Raises + ------ + ValueError + If both ``swap_keys`` and ``replace_rules`` are None (default) or not None. + OverDueDeprecation + If the current version is greater or equal to ``end_of_life_version``. + + See Also + -------- + warn_deprecated + + Notes + ----- + To prevent confusion exactly one of ``replace_rules`` and ``swap_keys`` + needs to be passed. + + Examples + -------- + For readability sake the warnings won't be shown in the examples. + + Swapping key names: + + >>> dict_to_check = {"foo": 123} + >>> deprecate_dict_entry( + dict_to_check=dict_to_check, + deprecated_usage="foo", + new_usage="bar", + to_be_removed_in_version="0.6.0", + swap_keys=("foo", "bar") + ) + >>> dict_to_check + {"bar": 123} + + Changing values: + + >>> dict_to_check = {"foo": 123} + >>> deprecate_dict_entry( + dict_to_check=dict_to_check, + deprecated_usage="foo: 123", + new_usage="foo: 123.0", + to_be_removed_in_version="0.6.0", + replace_rules=({"foo": 123}, {"foo": 123.0}) + ) + >>> dict_to_check + {"foo": 123.0} + + Swapping key names AND changing values: + + >>> dict_to_check = {"type": "kinetic-spectrum"} + >>> deprecate_dict_entry( + dict_to_check=dict_to_check, + deprecated_usage="type: kinectic-spectrum", + new_usage="default-megacomplex: decay", + to_be_removed_in_version="0.6.0", + replace_rules=({"type": "kinetic-spectrum"}, {"default-megacomplex": "decay"}) + ) + >>> dict_to_check + {"default-megacomplex": "decay"} + + + .. # noqa: DAR402 + """ + dict_changed = False + + if not np.logical_xor(swap_keys is None, replace_rules is None): + raise ValueError( + "Exactly one of the parameters `swap_keys` or `replace_rules` needs to be provided." 
+ ) + if swap_keys is not None and swap_keys[0] in dict_to_check: + dict_changed = True + dict_to_check[swap_keys[1]] = dict_to_check[swap_keys[0]] + del dict_to_check[swap_keys[0]] + if replace_rules is not None: + old_key, old_value = next(iter(replace_rules[0].items())) + new_key, new_value = next(iter(replace_rules[1].items())) + if old_key in dict_to_check and dict_to_check[old_key] == old_value: + dict_changed = True + dict_to_check[new_key] = new_value + if new_key != old_key: + del dict_to_check[old_key] + + if dict_changed: + warn_deprecated( + deprecated_qual_name_usage=deprecated_usage, + new_qual_name_usage=new_usage, + to_be_removed_in_version=to_be_removed_in_version, + stacklevel=stacklevel, + check_qual_names=(False, False), + ) + + def module_attribute(module_qual_name: str, attribute_name: str) -> Any: """Import and return the attribute (e.g. function or class) of a module. @@ -384,6 +531,11 @@ def deprecate_module_attribute( Any Module attribute from its new location. + Raises + ------ + OverDueDeprecation + If the current version is greater or equal to ``end_of_life_version``. + See Also -------- deprecate @@ -412,6 +564,8 @@ def __getattr__(attribute_name: str): raise AttributeError(f"module {__name__} has no attribute {attribute_name}") + + .. # noqa: DAR402 """ module_name = ".".join(new_qual_name.split(".")[:-1]) attribute_name = new_qual_name.split(".")[-1] @@ -457,6 +611,11 @@ def deprecate_submodule( ModuleType Module containing + Raises + ------ + OverDueDeprecation + If the current version is greater or equal to ``end_of_life_version``. + See Also -------- deprecate @@ -478,6 +637,9 @@ def deprecate_submodule( new_module_name="glotaran.project.result", to_be_removed_in_version="0.6.0", ) + + + .. # noqa: DAR402 """ new_module = import_module(new_module_name) deprecated_module = ModuleType( diff --git a/glotaran/deprecation/modules/builtin_io_yml.py b/glotaran/deprecation/modules/builtin_io_yml.py new file mode 100644 index 000000000..635a5a9a9 --- /dev/null +++ b/glotaran/deprecation/modules/builtin_io_yml.py @@ -0,0 +1,87 @@ +"""Deprecation functions for the yaml parser.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from glotaran.deprecation import deprecate_dict_entry + +if TYPE_CHECKING: + from typing import Any + from typing import MutableMapping + + +def model_spec_deprecations(spec: MutableMapping[Any, Any]) -> None: + """Check deprecations in the model specification ``spec`` dict. 
+ + Parameters + ---------- + spec : MutableMapping[Any, Any] + Model specification dictionary + """ + load_model_stack_level = 7 + deprecate_dict_entry( + dict_to_check=spec, + deprecated_usage="type: kinetic-spectrum", + new_usage="default-megacomplex: decay", + to_be_removed_in_version="0.7.0", + replace_rules=({"type": "kinetic-spectrum"}, {"default-megacomplex": "decay"}), + stacklevel=load_model_stack_level, + ) + + deprecate_dict_entry( + dict_to_check=spec, + deprecated_usage="type: spectrum", + new_usage="default-megacomplex: spectral", + to_be_removed_in_version="0.7.0", + replace_rules=({"type": "spectrum"}, {"default-megacomplex": "spectral"}), + stacklevel=load_model_stack_level, + ) + + deprecate_dict_entry( + dict_to_check=spec, + deprecated_usage="spectral_relations", + new_usage="relations", + to_be_removed_in_version="0.7.0", + swap_keys=("spectral_relations", "relations"), + stacklevel=load_model_stack_level, + ) + + if "relations" in spec: + for relation in spec["relations"]: + deprecate_dict_entry( + dict_to_check=relation, + deprecated_usage="compartment", + new_usage="source", + to_be_removed_in_version="0.7.0", + swap_keys=("compartment", "source"), + stacklevel=load_model_stack_level, + ) + + deprecate_dict_entry( + dict_to_check=spec, + deprecated_usage="spectral_constraints", + new_usage="constraints", + to_be_removed_in_version="0.7.0", + swap_keys=("spectral_constraints", "constraints"), + stacklevel=load_model_stack_level, + ) + + if "constraints" in spec: + for constraint in spec["constraints"]: + deprecate_dict_entry( + dict_to_check=constraint, + deprecated_usage="constraint.compartment", + new_usage="constraint.target", + to_be_removed_in_version="0.7.0", + swap_keys=("compartment", "target"), + stacklevel=load_model_stack_level, + ) + + deprecate_dict_entry( + dict_to_check=spec, + deprecated_usage="equal_area_penalties", + new_usage="clp_area_penalties", + to_be_removed_in_version="0.7.0", + swap_keys=("equal_area_penalties", "clp_area_penalties"), + stacklevel=load_model_stack_level, + ) diff --git a/glotaran/deprecation/modules/test/__init__.py b/glotaran/deprecation/modules/test/__init__.py index dfff9c230..9e67df7e7 100644 --- a/glotaran/deprecation/modules/test/__init__.py +++ b/glotaran/deprecation/modules/test/__init__.py @@ -1,16 +1,21 @@ """Package with deprecation tests and helper functions.""" from __future__ import annotations +from pathlib import Path from typing import TYPE_CHECKING import pytest +from glotaran.deprecation.deprecation_utils import GlotaranApiDeprecationWarning + if TYPE_CHECKING: from typing import Any from typing import Callable from typing import Mapping from typing import Sequence + from _pytest.recwarn import WarningsRecorder + def deprecation_warning_on_call_test_helper( deprecated_callable: Callable[..., Any], @@ -18,7 +23,7 @@ def deprecation_warning_on_call_test_helper( raise_exception=False, args: Sequence[Any] = [], kwargs: Mapping[str, Any] = {}, -) -> Any: +) -> tuple[WarningsRecorder, Any]: """Helperfunction to quickly test that a deprecated class or function warns. By default this ignores error when calling the function/class, @@ -41,17 +46,24 @@ def deprecation_warning_on_call_test_helper( Returns ------- - Any - Return value of deprecated_callable + tuple[WarningsRecorder, Any] + Tuple of the WarningsRecorder and return value of deprecated_callable Raises ------ Exception Exception caused by deprecated_callable if raise_exception is True. 
""" - with pytest.warns(DeprecationWarning): + with pytest.warns(GlotaranApiDeprecationWarning) as record: try: - return deprecated_callable(*args, **kwargs) - except Exception: + result = deprecated_callable(*args, **kwargs) + + assert len(record) >= 1 + assert Path(record[0].filename) == Path(__file__) + + return record, result + + except Exception as e: if raise_exception: - raise + raise e + return record, None diff --git a/glotaran/deprecation/modules/test/test_builtin_io_yml.py b/glotaran/deprecation/modules/test/test_builtin_io_yml.py new file mode 100644 index 000000000..26202b947 --- /dev/null +++ b/glotaran/deprecation/modules/test/test_builtin_io_yml.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from textwrap import dedent +from typing import TYPE_CHECKING + +import pytest + +import glotaran.builtin.io.yml.yml as yml_module +from glotaran.deprecation.modules.test import deprecation_warning_on_call_test_helper +from glotaran.io import load_model + +if TYPE_CHECKING: + from typing import Any + + from _pytest.monkeypatch import MonkeyPatch + + +@pytest.mark.parametrize( + "model_yml_str, expected_nr_of_warnings, expected_key, expected_value", + ( + ("type: kinetic-spectrum", 1, "default-megacomplex", "decay"), + ("type: spectrum", 1, "default-megacomplex", "spectral"), + ( + dedent( + """ + spectral_relations: + - compartment: s1 + - compartment: s2 + """ + ), + 3, + "relations", + [{"source": "s1"}, {"source": "s2"}], + ), + ( + dedent( + """ + spectral_constraints: + - compartment: s1 + - compartment: s2 + """ + ), + 3, + "constraints", + [{"target": "s1"}, {"target": "s2"}], + ), + ( + dedent( + """ + equal_area_penalties: + - type: equal_area + """ + ), + 1, + "clp_area_penalties", + [{"type": "equal_area"}], + ), + ), + ids=( + "type: kinetic-spectrum", + "type: spectrum", + "spectral_relations", + "spectral_constraints", + "equal_area_penalties", + ), +) +def test_model_spec_deprecations( + monkeypatch: MonkeyPatch, + model_yml_str: str, + expected_nr_of_warnings: int, + expected_key: str, + expected_value: Any, +): + """Warning gets emitted by load_model""" + return_dicts = [] + with monkeypatch.context() as m: + m.setattr(yml_module, "sanitize_yaml", lambda spec: return_dicts.append(spec)) + record, _ = deprecation_warning_on_call_test_helper( + load_model, args=(model_yml_str,), kwargs={"format_name": "yml_str"} + ) + + return_dict = return_dicts[0] + + assert expected_key in return_dict + assert return_dict[expected_key] == expected_value + + assert len(record) == expected_nr_of_warnings diff --git a/glotaran/deprecation/modules/test/test_changed_imports.py b/glotaran/deprecation/modules/test/test_changed_imports.py index 1911cc80d..2c7673bd7 100644 --- a/glotaran/deprecation/modules/test/test_changed_imports.py +++ b/glotaran/deprecation/modules/test/test_changed_imports.py @@ -6,6 +6,7 @@ import pytest +from glotaran.deprecation.deprecation_utils import GlotaranApiDeprecationWarning from glotaran.deprecation.deprecation_utils import module_attribute from glotaran.io import load_dataset from glotaran.parameter import ParameterGroup @@ -24,7 +25,7 @@ def check_recwarn(records: WarningsRecorder, warn_nr=1): print(record) assert len(records) == warn_nr - assert records[0].category == DeprecationWarning + assert records[0].category == GlotaranApiDeprecationWarning records.clear() diff --git a/glotaran/deprecation/modules/test/test_glotaran_root.py b/glotaran/deprecation/modules/test/test_glotaran_root.py index bb02eab34..f1a8f9cc6 100644 --- 
a/glotaran/deprecation/modules/test/test_glotaran_root.py +++ b/glotaran/deprecation/modules/test/test_glotaran_root.py @@ -11,6 +11,7 @@ from glotaran import read_parameters_from_csv_file from glotaran import read_parameters_from_yaml from glotaran import read_parameters_from_yaml_file +from glotaran.deprecation.deprecation_utils import GlotaranApiDeprecationWarning from glotaran.deprecation.modules.test import deprecation_warning_on_call_test_helper from glotaran.model import Model from glotaran.parameter import ParameterGroup @@ -20,7 +21,7 @@ def dummy_warn(foo, bar=False): - warn(DeprecationWarning("foo")) + warn(GlotaranApiDeprecationWarning("foo"), stacklevel=2) if not isinstance(bar, bool): raise ValueError("not a bool") return foo, bar @@ -32,10 +33,10 @@ def dummy_no_warn(foo, bar=False): def test_deprecation_warning_on_call_test_helper(): """Correct result passed on""" - result = deprecation_warning_on_call_test_helper( + record, result = deprecation_warning_on_call_test_helper( dummy_warn, args=["foo"], kwargs={"bar": True} ) - + assert len(record) == 1 assert result == ("foo", True) @@ -60,7 +61,7 @@ def test_read_model_from_yaml(): type: kinetic-spectrum megacomplex: {} """ - result = deprecation_warning_on_call_test_helper( + _, result = deprecation_warning_on_call_test_helper( read_model_from_yaml, args=[yaml], raise_exception=True ) @@ -75,7 +76,7 @@ def test_read_model_from_yaml_file(tmp_path: Path): """ model_file = tmp_path / "model.yaml" model_file.write_text(yaml) - result = deprecation_warning_on_call_test_helper( + _, result = deprecation_warning_on_call_test_helper( read_model_from_yaml_file, args=[str(model_file)], raise_exception=True ) @@ -86,7 +87,7 @@ def test_read_parameters_from_csv_file(tmp_path: Path): """read_parameters_from_csv_file raises warning""" parameters_file = tmp_path / "parameters.csv" parameters_file.write_text("label,value\nfoo,123") - result = deprecation_warning_on_call_test_helper( + _, result = deprecation_warning_on_call_test_helper( read_parameters_from_csv_file, args=[str(parameters_file)], raise_exception=True, @@ -98,7 +99,7 @@ def test_read_parameters_from_csv_file(tmp_path: Path): def test_read_parameters_from_yaml(): """read_parameters_from_yaml raises warning""" - result = deprecation_warning_on_call_test_helper( + _, result = deprecation_warning_on_call_test_helper( read_parameters_from_yaml, args=["foo:\n - 123"], raise_exception=True ) @@ -111,7 +112,7 @@ def test_read_parameters_from_yaml_file(tmp_path: Path): """read_parameters_from_yaml_file raises warning""" parameters_file = tmp_path / "parameters.yaml" parameters_file.write_text("foo:\n - 123") - result = deprecation_warning_on_call_test_helper( + _, result = deprecation_warning_on_call_test_helper( read_parameters_from_yaml_file, args=[str(parameters_file)], raise_exception=True ) diff --git a/glotaran/deprecation/modules/test/test_project_result.py b/glotaran/deprecation/modules/test/test_project_result.py index 7dc9ec6a4..83d2c5248 100644 --- a/glotaran/deprecation/modules/test/test_project_result.py +++ b/glotaran/deprecation/modules/test/test_project_result.py @@ -33,7 +33,7 @@ def test_Result_save_method(tmpdir: LocalPath, dummy_result: Result): # noqa: F def test_Result_get_dataset_method(dummy_result: Result): # noqa: F811 """Result.get_dataset(dataset_label) gives correct dataset.""" - result = deprecation_warning_on_call_test_helper( + _, result = deprecation_warning_on_call_test_helper( dummy_result.get_dataset, args=["dataset1"], raise_exception=True ) diff 
--git a/glotaran/deprecation/modules/test/test_project_sheme.py b/glotaran/deprecation/modules/test/test_project_sheme.py index 93ba18793..42ce6daa1 100644 --- a/glotaran/deprecation/modules/test/test_project_sheme.py +++ b/glotaran/deprecation/modules/test/test_project_sheme.py @@ -46,7 +46,7 @@ def test_Scheme_from_yaml_file_method(tmp_path: Path): dataset1: {dataset_path}""" ) - result = deprecation_warning_on_call_test_helper( + _, result = deprecation_warning_on_call_test_helper( Scheme.from_yaml_file, args=[str(scheme_path)], raise_exception=True ) diff --git a/glotaran/deprecation/test/test_deprecation_utils.py b/glotaran/deprecation/test/test_deprecation_utils.py index 08db4e82c..786864d25 100644 --- a/glotaran/deprecation/test/test_deprecation_utils.py +++ b/glotaran/deprecation/test/test_deprecation_utils.py @@ -7,14 +7,20 @@ import pytest import glotaran +from glotaran.deprecation.deprecation_utils import GlotaranApiDeprecationWarning from glotaran.deprecation.deprecation_utils import OverDueDeprecation from glotaran.deprecation.deprecation_utils import deprecate +from glotaran.deprecation.deprecation_utils import deprecate_dict_entry from glotaran.deprecation.deprecation_utils import glotaran_version from glotaran.deprecation.deprecation_utils import module_attribute from glotaran.deprecation.deprecation_utils import parse_version from glotaran.deprecation.deprecation_utils import warn_deprecated if TYPE_CHECKING: + from typing import Any + from typing import Hashable + from typing import Mapping + from _pytest.monkeypatch import MonkeyPatch from _pytest.recwarn import WarningsRecorder @@ -79,7 +85,7 @@ def test_parse_version_errors(version_str: str): @pytest.mark.usefixtures("glotaran_0_3_0") def test_warn_deprecated(): """Warning gets shown when all is in order.""" - with pytest.warns(DeprecationWarning) as record: + with pytest.warns(GlotaranApiDeprecationWarning) as record: warn_deprecated( deprecated_qual_name_usage=DEPRECATION_QUAL_NAME, new_qual_name_usage=NEW_QUAL_NAME, @@ -169,7 +175,7 @@ def test_warn_deprecated_broken_qualname_no_check( deprecated_qual_name_usage: str, new_qual_name_usage: str, check_qualnames: tuple[bool, bool] ): """Not checking broken imports.""" - with pytest.warns(DeprecationWarning): + with pytest.warns(GlotaranApiDeprecationWarning): warn_deprecated( deprecated_qual_name_usage=deprecated_qual_name_usage, new_qual_name_usage=new_qual_name_usage, @@ -181,7 +187,7 @@ def test_warn_deprecated_broken_qualname_no_check( @pytest.mark.usefixtures("glotaran_0_3_0") def test_warn_deprecated_sliced_method(): """Slice away method for importing and check class for attribute""" - with pytest.warns(DeprecationWarning): + with pytest.warns(GlotaranApiDeprecationWarning): warn_deprecated( deprecated_qual_name_usage=( "glotaran.deprecation.test.test_deprecation_utils.DummyClass.foo()" @@ -195,7 +201,7 @@ def test_warn_deprecated_sliced_method(): @pytest.mark.usefixtures("glotaran_0_3_0") def test_warn_deprecated_sliced_mapping(): """Slice away mapping for importing and check class for attribute""" - with pytest.warns(DeprecationWarning): + with pytest.warns(GlotaranApiDeprecationWarning): warn_deprecated( deprecated_qual_name_usage=( "glotaran.deprecation.test.test_deprecation_utils.DummyClass.foo['bar']" @@ -238,7 +244,7 @@ def dummy(): assert dummy.__doc__ == "Dummy docstring for testing." 
assert len(recwarn) == 1 - assert recwarn[0].category == DeprecationWarning + assert recwarn[0].category == GlotaranApiDeprecationWarning assert recwarn[0].message.args[0] == DEPRECATION_WARN_MESSAGE # type: ignore [union-attr] assert Path(recwarn[0].filename) == Path(__file__) @@ -269,7 +275,7 @@ def from_string(cls, string: str): assert Foo.__doc__ == "Foo class docstring for testing." assert len(recwarn) == 1 - assert recwarn[0].category == DeprecationWarning + assert recwarn[0].category == GlotaranApiDeprecationWarning assert recwarn[0].message.args[0] == DEPRECATION_WARN_MESSAGE # type: ignore [union-attr] assert Path(recwarn[0].filename) == Path(__file__) @@ -278,6 +284,145 @@ def from_string(cls, string: str): assert len(recwarn) == 2 +@pytest.mark.usefixtures("glotaran_0_3_0") +def test_deprecate_dict_key_swap_keys(): + """Replace old with new key while keeping the value.""" + test_dict = {"foo": 123} + with pytest.warns( + GlotaranApiDeprecationWarning, match="'foo'.+was deprecated, use 'bar'" + ) as record: + deprecate_dict_entry( + dict_to_check=test_dict, + deprecated_usage="foo", + new_usage="bar", + to_be_removed_in_version="0.6.0", + swap_keys=("foo", "bar"), + ) + + assert "bar" in test_dict + assert test_dict["bar"] == 123 + assert "foo" not in test_dict + + assert len(record) == 1 + assert Path(record[0].filename) == Path(__file__) + + +@pytest.mark.usefixtures("glotaran_0_3_0") +def test_deprecate_dict_key_replace_rules_only_values(): + """Replace old value for key with new value.""" + test_dict = {"foo": 123} + with pytest.warns( + GlotaranApiDeprecationWarning, match="'foo: 123'.+was deprecated, use 'foo: 321'" + ) as record: + deprecate_dict_entry( + dict_to_check=test_dict, + deprecated_usage="foo: 123", + new_usage="foo: 321", + to_be_removed_in_version="0.6.0", + replace_rules=({"foo": 123}, {"foo": 321}), + ) + + assert "foo" in test_dict + assert test_dict["foo"] == 321 + + assert len(record) == 1 + assert Path(record[0].filename) == Path(__file__) + + +@pytest.mark.usefixtures("glotaran_0_3_0") +def test_deprecate_dict_key_replace_rules_keys_and_values(): + """Replace old with new key AND replace old value for key with new value.""" + test_dict = {"foo": 123} + with pytest.warns( + GlotaranApiDeprecationWarning, match="'foo: 123'.+was deprecated, use 'bar: 321'" + ) as record: + deprecate_dict_entry( + dict_to_check=test_dict, + deprecated_usage="foo: 123", + new_usage="bar: 321", + to_be_removed_in_version="0.6.0", + replace_rules=({"foo": 123}, {"bar": 321}), + ) + + assert "bar" in test_dict + assert test_dict["bar"] == 321 + assert "foo" not in test_dict + + assert len(record) == 1 + assert Path(record[0].filename) == Path(__file__) + + +@pytest.mark.xfail(strict=True) +@pytest.mark.usefixtures("glotaran_0_3_0") +def test_deprecate_dict_key_does_not_apply_swap_keys(): + """Don't warn if the dict doesn't change because old_key didn't match""" + + with pytest.warns( + GlotaranApiDeprecationWarning, match="'foo: 123'.+was deprecated, use 'foo: 321'" + ): + deprecate_dict_entry( + dict_to_check={"foo": 123}, + deprecated_usage="foo: 123", + new_usage="foo: 321", + to_be_removed_in_version="0.6.0", + swap_keys=("bar", "baz"), + ) + + +@pytest.mark.xfail(strict=True) +@pytest.mark.parametrize( + "replace_rules", + ( + ({"bar": 123}, {"bar": 321}), + ({"foo": 111}, {"bar": 321}), + ), +) +@pytest.mark.usefixtures("glotaran_0_3_0") +def test_deprecate_dict_key_does_not_apply( + replace_rules: tuple[Mapping[Hashable, Any], Mapping[Hashable, Any]] +): + """Don't warn 
if the dict doesn't change because old_key or old_value didn't match""" + with pytest.warns( + GlotaranApiDeprecationWarning, match="'foo: 123'.+was deprecated, use 'foo: 321'" + ): + deprecate_dict_entry( + dict_to_check={"foo": 123}, + deprecated_usage="foo: 123", + new_usage="foo: 321", + to_be_removed_in_version="0.6.0", + replace_rules=replace_rules, + ) + + +@pytest.mark.parametrize( + "swap_keys, replace_rules", + ( + (None, None), + (("bar", "baz"), ({"bar": 1}, {"baz": 2})), + ), +) +@pytest.mark.usefixtures("glotaran_0_3_0") +def test_deprecate_dict_key_error_no_action( + swap_keys: tuple[Hashable, Hashable] | None, + replace_rules: tuple[Mapping[Hashable, Any], Mapping[Hashable, Any]] | None, +): + """Raise error if none or both `swap_keys` and `replace_rules` were provided.""" + with pytest.raises( + ValueError, + match=( + r"Exactly one of the parameters `swap_keys` or `replace_rules` needs to be provided\." + ), + ): + deprecate_dict_entry( + dict_to_check={}, + deprecated_usage="", + new_usage="", + to_be_removed_in_version="", + swap_keys=swap_keys, + replace_rules=replace_rules, + ) + + def test_module_attribute(): """Same code as the original import""" @@ -290,7 +435,7 @@ def test_module_attribute(): def test_deprecate_module_attribute(): """Same code as the original import and warning""" - with pytest.warns(DeprecationWarning) as record: + with pytest.warns(GlotaranApiDeprecationWarning) as record: from glotaran.deprecation.test.dummy_package.deprecated_module_attribute import ( deprecated_attribute, @@ -312,7 +457,7 @@ def test_deprecate_submodule(recwarn: WarningsRecorder): ) assert len(recwarn) == 1 - assert recwarn[0].category == DeprecationWarning + assert recwarn[0].category == GlotaranApiDeprecationWarning @pytest.mark.usefixtures("glotaran_0_3_0") @@ -324,7 +469,7 @@ def test_deprecate_submodule_from_import(recwarn: WarningsRecorder): ) assert len(recwarn) == 1 - assert recwarn[0].category == DeprecationWarning + assert recwarn[0].category == GlotaranApiDeprecationWarning assert Path(recwarn[0].filename) == Path(__file__) diff --git a/glotaran/utils/sanitize.py b/glotaran/utils/sanitize.py index d26a97d3e..8a3f40ebe 100644 --- a/glotaran/utils/sanitize.py +++ b/glotaran/utils/sanitize.py @@ -3,7 +3,6 @@ from typing import Any -from glotaran.deprecation import warn_deprecated from glotaran.utils.regex import RegexPattern as rp @@ -216,83 +215,3 @@ def sanitize_parameter_list(parameter_list: list[str | float]) -> list[str | flo parameter_list[i] = convert_scientific_to_float(value) return parameter_list - - -def check_deprecations(spec: dict): - """Check deprecations in a `spec` dict. 
- - Parameters - ---------- - spec : dict - A specification dictionary - """ - if "type" in spec: - if spec["type"] == "kinetic-spectrum": - warn_deprecated( - deprecated_qual_name_usage="type: kinectic-spectrum", - new_qual_name_usage="default-megacomplex: decay", - to_be_removed_in_version="0.6.0", - check_qual_names=(False, False), - ) - spec["default-megacomplex"] = "decay" - elif spec["type"] == "spectral": - warn_deprecated( - deprecated_qual_name_usage="type: spectral", - new_qual_name_usage="default-megacomplex: spectral", - to_be_removed_in_version="0.6.0", - check_qual_names=(False, False), - ) - spec["default-megacomplex"] = "spectral" - del spec["type"] - - if "spectral_relations" in spec: - warn_deprecated( - deprecated_qual_name_usage="spectral_relations", - new_qual_name_usage="relations", - to_be_removed_in_version="0.6.0", - check_qual_names=(False, False), - ) - spec["relations"] = spec["spectral_relations"] - del spec["spectral_relations"] - - for i, relation in enumerate(spec["relations"]): - if "compartment" in relation: - warn_deprecated( - deprecated_qual_name_usage="relation.compartment", - new_qual_name_usage="relation.source", - to_be_removed_in_version="0.6.0", - check_qual_names=(False, False), - ) - relation["source"] = relation["compartment"] - del relation["compartment"] - - if "spectral_constraints" in spec: - warn_deprecated( - deprecated_qual_name_usage="spectral_constraints", - new_qual_name_usage="constraints", - to_be_removed_in_version="0.6.0", - check_qual_names=(False, False), - ) - spec["constraints"] = spec["spectral_constraints"] - del spec["spectral_constraints"] - - for i, constraint in enumerate(spec["constraints"]): - if "compartment" in constraint: - warn_deprecated( - deprecated_qual_name_usage="constraint.compartment", - new_qual_name_usage="constraint.target", - to_be_removed_in_version="0.6.0", - check_qual_names=(False, False), - ) - constraint["target"] = constraint["compartment"] - del constraint["compartment"] - - if "equal_area_penalties" in spec: - warn_deprecated( - deprecated_qual_name_usage="equal_area_penalties", - new_qual_name_usage="clp_area_penalties", - to_be_removed_in_version="0.6.0", - check_qual_names=(False, False), - ) - spec["clp_area_penalties"] = spec["equal_area_penalties"] - del spec["equal_area_penalties"] diff --git a/tox.ini b/tox.ini index 6fe9c651c..4a4dfeee2 100644 --- a/tox.ini +++ b/tox.ini @@ -11,7 +11,7 @@ envlist = py{38}, pre-commit, docs, docs-notebooks, docs-links ; Uncomment to ignore deprecation warnings coming from pyglotaran ; (this helps to see the warnings from dependencies) ; filterwarnings = -; ignore:.+glotaran:DeprecationWarning +; ignore:.+glotaran:GlotaranApiDeprecationWarning [flake8] extend-ignore = E231, E203 From 998bad52565f26ce9c1c97f4cc813b085c75e416 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn=20Wei=C3=9Fenborn?= Date: Sun, 15 Aug 2021 16:05:25 +0200 Subject: [PATCH 11/29] =?UTF-8?q?=E2=9C=A8=20Damped=20Oscillation=20Megaco?= =?UTF-8?q?mplex=20(#764)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added DampedOscillationMegacomplex * Added damped-oscillation to setup.cfg plugins Co-authored-by: Joris Snellenburg --- glotaran/analysis/test/models.py | 42 +- glotaran/analysis/util.py | 21 +- .../damped_oscillation/__init__.py | 3 + .../damped_oscillation_megacomplex.py | 201 +++++++++ .../test/test_doas_model.py | 409 ++++++++++++++++++ .../megacomplexes/decay/decay_megacomplex.py | 7 +- glotaran/builtin/megacomplexes/decay/irf.py 
| 6 + .../decay/test/test_decay_megacomplex.py | 21 +- glotaran/builtin/megacomplexes/decay/util.py | 6 +- .../spectral/test/test_spectral_model.py | 21 +- glotaran/model/model.py | 6 +- glotaran/test/test_spectral_penalties.py | 26 +- setup.cfg | 1 + 13 files changed, 725 insertions(+), 45 deletions(-) create mode 100755 glotaran/builtin/megacomplexes/damped_oscillation/__init__.py create mode 100644 glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py create mode 100755 glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py diff --git a/glotaran/analysis/test/models.py b/glotaran/analysis/test/models.py index 004e3746e..9c56f7e58 100644 --- a/glotaran/analysis/test/models.py +++ b/glotaran/analysis/test/models.py @@ -56,13 +56,23 @@ def index_dependent(self, dataset_model): class SimpleTestModel(Model): @classmethod - def from_dict(cls, model_dict): + def from_dict( + cls, + model_dict, + *, + megacomplex_types: dict[str, type[Megacomplex]] | None = None, + default_megacomplex_type: str | None = None, + ): + defaults: dict[str, type[Megacomplex]] = { + "model_complex": SimpleTestMegacomplex, + "global_complex": SimpleTestMegacomplexGlobal, + } + if megacomplex_types is not None: + defaults.update(megacomplex_types) return super().from_dict( model_dict, - megacomplex_types={ - "model_complex": SimpleTestMegacomplex, - "global_complex": SimpleTestMegacomplexGlobal, - }, + megacomplex_types=defaults, + default_megacomplex_type=default_megacomplex_type, ) @@ -157,14 +167,24 @@ def finalize_data( class DecayModel(Model): @classmethod - def from_dict(cls, model_dict): + def from_dict( + cls, + model_dict, + *, + megacomplex_types: dict[str, type[Megacomplex]] | None = None, + default_megacomplex_type: str | None = None, + ): + defaults: dict[str, type[Megacomplex]] = { + "model_complex": SimpleKineticMegacomplex, + "global_complex": SimpleSpectralMegacomplex, + "global_complex_shaped": ShapedSpectralMegacomplex, + } + if megacomplex_types is not None: + defaults.update(megacomplex_types) return super().from_dict( model_dict, - megacomplex_types={ - "model_complex": SimpleKineticMegacomplex, - "global_complex": SimpleSpectralMegacomplex, - "global_complex_shaped": ShapedSpectralMegacomplex, - }, + megacomplex_types=defaults, + default_megacomplex_type=default_megacomplex_type, ) diff --git a/glotaran/analysis/util.py b/glotaran/analysis/util.py index 25b44c3b2..d54cb06b0 100644 --- a/glotaran/analysis/util.py +++ b/glotaran/analysis/util.py @@ -71,15 +71,7 @@ def calculate_matrix( clp_labels = this_clp_labels matrix = this_matrix else: - tmp_clp_labels = clp_labels + [c for c in this_clp_labels if c not in clp_labels] - tmp_matrix = np.zeros((matrix.shape[0], len(tmp_clp_labels)), dtype=np.float64) - for idx, label in enumerate(tmp_clp_labels): - if label in clp_labels: - tmp_matrix[:, idx] += matrix[:, clp_labels.index(label)] - if label in this_clp_labels: - tmp_matrix[:, idx] += this_matrix[:, this_clp_labels.index(label)] - clp_labels = tmp_clp_labels - matrix = tmp_matrix + clp_labels, matrix = combine_matrix(matrix, this_matrix, clp_labels, this_clp_labels) if as_global_model: dataset_model.swap_dimensions() @@ -87,6 +79,17 @@ def calculate_matrix( return CalculatedMatrix(clp_labels, matrix) +def combine_matrix(matrix, this_matrix, clp_labels, this_clp_labels): + tmp_clp_labels = clp_labels + [c for c in this_clp_labels if c not in clp_labels] + tmp_matrix = np.zeros((matrix.shape[0], len(tmp_clp_labels)), dtype=np.float64) + for idx, 
label in enumerate(tmp_clp_labels): + if label in clp_labels: + tmp_matrix[:, idx] += matrix[:, clp_labels.index(label)] + if label in this_clp_labels: + tmp_matrix[:, idx] += this_matrix[:, this_clp_labels.index(label)] + return tmp_clp_labels, tmp_matrix + + @nb.jit(nopython=True, parallel=True) def apply_weight(matrix, weight): for i in nb.prange(matrix.shape[1]): diff --git a/glotaran/builtin/megacomplexes/damped_oscillation/__init__.py b/glotaran/builtin/megacomplexes/damped_oscillation/__init__.py new file mode 100755 index 000000000..2f975d246 --- /dev/null +++ b/glotaran/builtin/megacomplexes/damped_oscillation/__init__.py @@ -0,0 +1,3 @@ +from glotaran.builtin.megacomplexes.damped_oscillation.damped_oscillation_megacomplex import ( + DampedOscillationMegacomplex, +) diff --git a/glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py b/glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py new file mode 100644 index 000000000..d57996a6f --- /dev/null +++ b/glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +from typing import List + +import numba as nb +import numpy as np +import xarray as xr +from scipy.special import erf + +from glotaran.builtin.megacomplexes.decay.irf import Irf +from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian +from glotaran.model import DatasetModel +from glotaran.model import Megacomplex +from glotaran.model import Model +from glotaran.model import megacomplex +from glotaran.model.item import model_item_validator +from glotaran.parameter import Parameter + + +@megacomplex( + dimension="time", + dataset_model_items={ + "irf": {"type": Irf, "allow_none": True}, + }, + properties={ + "labels": List[str], + "frequencies": List[Parameter], + "rates": List[Parameter], + }, + register_as="damped-oscillation", +) +class DampedOscillationMegacomplex(Megacomplex): + @model_item_validator(False) + def ensure_oscillation_paramater(self, model: Model) -> list[str]: + + problems = [] + + if len(self.labels) != len(self.frequencies) or len(self.labels) != len(self.rates): + problems.append( + f"Size of labels ({len(self.labels)}), frequencies ({len(self.frequencies)}) " + f"and rates ({len(self.rates)}) does not match for damped oscillation " + f"megacomplex '{self.label}'." 
+ ) + + return problems + + def calculate_matrix( + self, + dataset_model: DatasetModel, + indices: dict[str, int], + **kwargs, + ): + + clp_label = [f"{label}_cos" for label in self.labels] + [ + f"{label}_sin" for label in self.labels + ] + + model_axis = dataset_model.get_model_axis() + delta = np.abs(model_axis[1:] - model_axis[:-1]) + delta_min = delta[np.argmin(delta)] + frequency_max = 1 / (2 * 0.03 * delta_min) + frequencies = np.array(self.frequencies) * 0.03 * 2 * np.pi + frequencies[frequencies >= frequency_max] = np.mod( + frequencies[frequencies >= frequency_max], frequency_max + ) + rates = np.array(self.rates) + + matrix = np.ones((model_axis.size, len(clp_label)), dtype=np.float64) + + if dataset_model.irf is None: + calculate_damped_oscillation_matrix_no_irf(matrix, frequencies, rates, model_axis) + elif isinstance(dataset_model.irf, IrfMultiGaussian): + global_dimension = dataset_model.get_global_dimension() + global_axis = dataset_model.get_global_axis() + global_index = indices.get(global_dimension) + centers, widths, scales, shift, _, _ = dataset_model.irf.parameter( + global_index, global_axis + ) + for center, width, scale in zip(centers, widths, scales): + matrix += calculate_damped_oscillation_matrix_gaussian_irf( + frequencies, + rates, + model_axis, + center, + width, + shift, + scale, + ) + matrix /= np.sum(scales) + + return clp_label, matrix + + def index_dependent(self, dataset_model: DatasetModel) -> bool: + return ( + isinstance(dataset_model.irf, IrfMultiGaussian) + and dataset_model.irf.is_index_dependent() + ) + + def finalize_data( + self, + dataset_model: DatasetModel, + dataset: xr.Dataset, + is_full_model: bool = False, + as_global: bool = False, + ): + if is_full_model: + return + + megacomplexes = ( + dataset_model.global_megacomplex if is_full_model else dataset_model.megacomplex + ) + unique = len([m for m in megacomplexes if isinstance(m, DampedOscillationMegacomplex)]) < 2 + + prefix = "damped_oscillation" if unique else f"{self.label}_damped_oscillation" + + dataset.coords[f"{prefix}"] = self.labels + dataset.coords[f"{prefix}_frequency"] = (prefix, self.frequencies) + dataset.coords[f"{prefix}_rate"] = (prefix, self.rates) + + dim1 = dataset_model.get_global_axis().size + dim2 = len(self.labels) + doas = np.zeros((dim1, dim2), dtype=np.float64) + phase = np.zeros((dim1, dim2), dtype=np.float64) + for i, label in enumerate(self.labels): + sin = dataset.clp.sel(clp_label=f"{label}_sin") + cos = dataset.clp.sel(clp_label=f"{label}_cos") + doas[:, i] = np.sqrt(sin * sin + cos * cos) + phase[:, i] = np.unwrap(np.arctan2(sin, cos)) + + dataset[f"{prefix}_associated_spectra"] = ( + (dataset_model.get_global_dimension(), prefix), + doas, + ) + + dataset[f"{prefix}_phase"] = ( + (dataset_model.get_global_dimension(), prefix), + phase, + ) + + if not is_full_model: + if self.index_dependent(dataset_model): + dataset[f"{prefix}_sin"] = ( + ( + dataset_model.get_global_dimension(), + dataset_model.get_model_dimension(), + prefix, + ), + dataset.matrix.sel(clp_label=[f"{label}_sin" for label in self.labels]).values, + ) + + dataset[f"{prefix}_cos"] = ( + ( + dataset_model.get_global_dimension(), + dataset_model.get_model_dimension(), + prefix, + ), + dataset.matrix.sel(clp_label=[f"{label}_cos" for label in self.labels]).values, + ) + else: + dataset[f"{prefix}_sin"] = ( + (dataset_model.get_model_dimension(), prefix), + dataset.matrix.sel(clp_label=[f"{label}_sin" for label in self.labels]).values, + ) + + dataset[f"{prefix}_cos"] = ( + 
(dataset_model.get_model_dimension(), prefix), + dataset.matrix.sel(clp_label=[f"{label}_cos" for label in self.labels]).values, + ) + + +@nb.jit(nopython=True, parallel=True) +def calculate_damped_oscillation_matrix_no_irf(matrix, frequencies, rates, axis): + + idx = 0 + for frequency, rate in zip(frequencies, rates): + osc = np.exp(-rate * axis - 1j * frequency * axis) + matrix[:, idx] = osc.real + matrix[:, idx + 1] = osc.imag + idx += 2 + + +def calculate_damped_oscillation_matrix_gaussian_irf( + frequencies: np.ndarray, + rates: np.ndarray, + model_axis: np.ndarray, + center: float, + width: float, + shift: float, + scale: float, +): + shifted_axis = model_axis - center - shift + d = width ** 2 + k = rates + 1j * frequencies + dk = k * d + sqwidth = np.sqrt(2) * width + a = (-1 * shifted_axis[:, None] + 0.5 * dk) * k + a = np.minimum(a, 709) + a = np.exp(a) + b = 1 + erf((shifted_axis[:, None] - dk) / sqwidth) + osc = a * b * scale + return np.concatenate((osc.real, osc.imag), axis=1) diff --git a/glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py b/glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py new file mode 100755 index 000000000..fe1c29927 --- /dev/null +++ b/glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py @@ -0,0 +1,409 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from glotaran.analysis.optimize import optimize +from glotaran.analysis.simulation import simulate +from glotaran.builtin.megacomplexes.damped_oscillation import DampedOscillationMegacomplex +from glotaran.builtin.megacomplexes.decay import DecayMegacomplex +from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex +from glotaran.model import Megacomplex +from glotaran.model import Model +from glotaran.parameter import ParameterGroup +from glotaran.project import Scheme + + +class DampedOscillationsModel(Model): + @classmethod + def from_dict( + cls, + model_dict, + *, + megacomplex_types: dict[str, type[Megacomplex]] | None = None, + default_megacomplex_type: str | None = None, + ): + defaults: dict[str, type[Megacomplex]] = { + "damped_oscillation": DampedOscillationMegacomplex, + "decay": DecayMegacomplex, + "spectral": SpectralMegacomplex, + } + if megacomplex_types is not None: + defaults.update(megacomplex_types) + return super().from_dict( + model_dict, + megacomplex_types=defaults, + default_megacomplex_type=default_megacomplex_type, + ) + + +class OneOscillation: + sim_model = DampedOscillationsModel.from_dict( + { + "megacomplex": { + "m1": { + "type": "damped_oscillation", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + "m2": { + "type": "spectral", + "shape": {"osc1_cos": "sh1", "osc1_sin": "sh1"}, + }, + }, + "shape": { + "sh1": { + "type": "gaussian", + "amplitude": "shapes.amps.1", + "location": "shapes.locs.1", + "width": "shapes.width.1", + }, + }, + "dataset": {"dataset1": {"megacomplex": ["m1"], "global_megacomplex": ["m2"]}}, + } + ) + + model = DampedOscillationsModel.from_dict( + { + "megacomplex": { + "m1": { + "type": "damped_oscillation", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + }, + "dataset": {"dataset1": {"megacomplex": ["m1"]}}, + } + ) + + wanted_parameter = ParameterGroup.from_dict( + { + "osc": [ + ["freq", 25.5], + ["rate", 0.1], + ], + "shapes": {"amps": [7], "locs": [5], "width": [4]}, + } + ) + + parameter = ParameterGroup.from_dict( + { + "osc": [ + ["freq", 20], + ["rate", 0.3], + 
], + } + ) + + time = np.arange(0, 3, 0.01) + spectral = np.arange(0, 10) + axis = {"time": time, "spectral": spectral} + + wanted_clp = ["osc1_cos", "osc1_sin"] + wanted_shape = (300, 2) + + +class OneOscillationWithIrf: + sim_model = DampedOscillationsModel.from_dict( + { + "megacomplex": { + "m1": { + "type": "damped_oscillation", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + "m2": { + "type": "spectral", + "shape": {"osc1_cos": "sh1", "osc1_sin": "sh1"}, + }, + }, + "shape": { + "sh1": { + "type": "gaussian", + "amplitude": "shapes.amps.1", + "location": "shapes.locs.1", + "width": "shapes.width.1", + }, + }, + "irf": { + "irf1": { + "type": "gaussian", + "center": "irf.center", + "width": "irf.width", + }, + }, + "dataset": { + "dataset1": { + "megacomplex": ["m1"], + "global_megacomplex": ["m2"], + "irf": "irf1", + } + }, + } + ) + + model = DampedOscillationsModel.from_dict( + { + "megacomplex": { + "m1": { + "type": "damped_oscillation", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + }, + "irf": { + "irf1": { + "type": "gaussian", + "center": "irf.center", + "width": "irf.width", + }, + }, + "dataset": { + "dataset1": { + "megacomplex": ["m1"], + "irf": "irf1", + } + }, + } + ) + + wanted_parameter = ParameterGroup.from_dict( + { + "osc": [ + ["freq", 25], + ["rate", 0.1], + ], + "shapes": {"amps": [7], "locs": [5], "width": [4]}, + "irf": [["center", 0.3], ["width", 0.1]], + } + ) + + parameter = ParameterGroup.from_dict( + { + "osc": [ + ["freq", 25], + ["rate", 0.1], + ], + "irf": [["center", 0.3], ["width", 0.1]], + } + ) + + time = np.arange(0, 3, 0.01) + spectral = np.arange(0, 10) + axis = {"time": time, "spectral": spectral} + + wanted_clp = ["osc1_cos", "osc1_sin"] + wanted_shape = (300, 2) + + +class OneOscillationWithSequentialModel: + sim_model = DampedOscillationsModel.from_dict( + { + "initial_concentration": { + "j1": {"compartments": ["s1", "s2"], "parameters": ["j.1", "j.0"]}, + }, + "k_matrix": { + "k1": { + "matrix": { + ("s2", "s1"): "kinetic.1", + ("s2", "s2"): "kinetic.2", + } + } + }, + "megacomplex": { + "m1": {"type": "decay", "k_matrix": ["k1"]}, + "m2": { + "type": "damped_oscillation", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + "m3": { + "type": "spectral", + "shape": { + "osc1_cos": "sh1", + "osc1_sin": "sh1", + "s1": "sh2", + "s2": "sh3", + }, + }, + }, + "shape": { + "sh1": { + "type": "gaussian", + "amplitude": "shapes.amps.1", + "location": "shapes.locs.1", + "width": "shapes.width.1", + }, + "sh2": { + "type": "gaussian", + "amplitude": "shapes.amps.2", + "location": "shapes.locs.2", + "width": "shapes.width.2", + }, + "sh3": { + "type": "gaussian", + "amplitude": "shapes.amps.3", + "location": "shapes.locs.3", + "width": "shapes.width.3", + }, + }, + "irf": { + "irf1": { + "type": "gaussian", + "center": "irf.center", + "width": "irf.width", + }, + }, + "dataset": { + "dataset1": { + "initial_concentration": "j1", + "irf": "irf1", + "megacomplex": ["m1", "m2"], + "global_megacomplex": ["m3"], + } + }, + } + ) + + model = DampedOscillationsModel.from_dict( + { + "initial_concentration": { + "j1": {"compartments": ["s1", "s2"], "parameters": ["j.1", "j.0"]}, + }, + "k_matrix": { + "k1": { + "matrix": { + ("s2", "s1"): "kinetic.1", + ("s2", "s2"): "kinetic.2", + } + } + }, + "megacomplex": { + "m1": {"type": "decay", "k_matrix": ["k1"]}, + "m2": { + "type": "damped_oscillation", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + 
"rates": ["osc.rate"], + }, + }, + "irf": { + "irf1": { + "type": "gaussian", + "center": "irf.center", + "width": "irf.width", + }, + }, + "dataset": { + "dataset1": { + "initial_concentration": "j1", + "irf": "irf1", + "megacomplex": ["m1", "m2"], + } + }, + } + ) + + wanted_parameter = ParameterGroup.from_dict( + { + "j": [ + ["1", 1, {"vary": False, "non-negative": False}], + ["0", 0, {"vary": False, "non-negative": False}], + ], + "kinetic": [ + ["1", 0.2], + ["2", 0.01], + ], + "osc": [ + ["freq", 25], + ["rate", 0.1], + ], + "shapes": {"amps": [0.07, 2, 4], "locs": [5, 2, 8], "width": [4, 2, 3]}, + "irf": [["center", 0.3], ["width", 0.1]], + } + ) + + parameter = ParameterGroup.from_dict( + { + "j": [ + ["1", 1, {"vary": False, "non-negative": False}], + ["0", 0, {"vary": False, "non-negative": False}], + ], + "kinetic": [ + ["1", 0.2], + ["2", 0.01], + ], + "osc": [ + ["freq", 25], + ["rate", 0.1], + ], + "irf": [["center", 0.3], ["width", 0.1]], + } + ) + + time = np.arange(-1, 5, 0.01) + spectral = np.arange(0, 10) + axis = {"time": time, "spectral": spectral} + + wanted_clp = ["osc1_cos", "osc1_sin", "s1", "s2"] + wanted_shape = (600, 4) + + +@pytest.mark.parametrize( + "suite", + [ + OneOscillation, + OneOscillationWithIrf, + OneOscillationWithSequentialModel, + ], +) +def test_doas_model(suite): + + print(suite.sim_model.validate()) # noqa + assert suite.sim_model.valid() + + print(suite.model.validate()) # noqa + assert suite.model.valid() + + print(suite.sim_model.validate(suite.wanted_parameter)) # noqa + assert suite.sim_model.valid(suite.wanted_parameter) + + print(suite.model.validate(suite.parameter)) # noqa + assert suite.model.valid(suite.parameter) + + dataset = simulate(suite.sim_model, "dataset1", suite.wanted_parameter, suite.axis) + print(dataset) # noqa + + assert dataset.data.shape == (suite.axis["time"].size, suite.axis["spectral"].size) + + print(suite.parameter) # noqa + print(suite.wanted_parameter) # noqa + + data = {"dataset1": dataset} + scheme = Scheme( + model=suite.model, + parameters=suite.parameter, + data=data, + maximum_number_function_evaluations=20, + ) + result = optimize(scheme, raise_exception=True) + print(result.optimized_parameters) # noqa + + for label, param in result.optimized_parameters.all(): + assert np.allclose(param.value, suite.wanted_parameter.get(label).value, rtol=1e-1) + + resultdata = result.data["dataset1"] + assert np.array_equal(dataset["time"], resultdata["time"]) + assert np.array_equal(dataset["spectral"], resultdata["spectral"]) + assert dataset.data.shape == resultdata.fitted_data.shape + assert np.allclose(dataset.data, resultdata.fitted_data) + + assert "damped_oscillation_cos" in resultdata + assert "damped_oscillation_sin" in resultdata + assert "damped_oscillation_associated_spectra" in resultdata + assert "damped_oscillation_phase" in resultdata diff --git a/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py b/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py index e57f31f5e..cbc55b165 100644 --- a/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py @@ -9,7 +9,6 @@ from glotaran.builtin.megacomplexes.decay.initial_concentration import InitialConcentration from glotaran.builtin.megacomplexes.decay.irf import Irf from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian -from glotaran.builtin.megacomplexes.decay.irf import IrfSpectralMultiGaussian from glotaran.builtin.megacomplexes.decay.k_matrix import KMatrix from 
glotaran.builtin.megacomplexes.decay.util import decay_matrix_implementation from glotaran.builtin.megacomplexes.decay.util import retrieve_decay_associated_data @@ -57,10 +56,8 @@ def involved_compartments(self): def index_dependent(self, dataset_model: DatasetModel) -> bool: return ( - isinstance(dataset_model.irf, IrfSpectralMultiGaussian) - and dataset_model.irf.dispersion_center is not None - ) or ( - isinstance(dataset_model.irf, IrfMultiGaussian) and dataset_model.irf.shift is not None + isinstance(dataset_model.irf, IrfMultiGaussian) + and dataset_model.irf.is_index_dependent() ) def calculate_matrix( diff --git a/glotaran/builtin/megacomplexes/decay/irf.py b/glotaran/builtin/megacomplexes/decay/irf.py index e1297c74f..7ddbc17f6 100644 --- a/glotaran/builtin/megacomplexes/decay/irf.py +++ b/glotaran/builtin/megacomplexes/decay/irf.py @@ -106,6 +106,9 @@ def calculate(self, index: int, global_axis: np.ndarray, model_axis: np.ndarray) for center, width, scale in zip(centers, widths, scales) ) + def is_index_dependent(self): + return self.shift is not None + @model_item( properties={ @@ -191,6 +194,9 @@ def calculate_dispersion(self, axis): dispersion.append(center) return np.asarray(dispersion).T + def is_index_dependent(self): + return super().is_index_dependent() or self.dispersion_center is not None + @model_item( properties={ diff --git a/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py index 32e266128..f788a6b90 100644 --- a/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import pytest import xarray as xr @@ -5,6 +7,7 @@ from glotaran.analysis.optimize import optimize from glotaran.analysis.simulation import simulate from glotaran.builtin.megacomplexes.decay import DecayMegacomplex +from glotaran.model import Megacomplex from glotaran.model import Model from glotaran.parameter import ParameterGroup from glotaran.project import Scheme @@ -22,12 +25,22 @@ def _create_gaussian_clp(labels, amplitudes, centers, widths, axis): class DecayModel(Model): @classmethod - def from_dict(cls, model_dict): + def from_dict( + cls, + model_dict, + *, + megacomplex_types: dict[str, type[Megacomplex]] | None = None, + default_megacomplex_type: str | None = None, + ): + defaults: dict[str, type[Megacomplex]] = { + "decay": DecayMegacomplex, + } + if megacomplex_types is not None: + defaults.update(megacomplex_types) return super().from_dict( model_dict, - megacomplex_types={ - "decay": DecayMegacomplex, - }, + megacomplex_types=defaults, + default_megacomplex_type=default_megacomplex_type, ) diff --git a/glotaran/builtin/megacomplexes/decay/util.py b/glotaran/builtin/megacomplexes/decay/util.py index b984fe4cd..c51a3ce67 100644 --- a/glotaran/builtin/megacomplexes/decay/util.py +++ b/glotaran/builtin/megacomplexes/decay/util.py @@ -173,7 +173,11 @@ def retrieve_decay_associated_data( das = dataset[f"species_associated_{name}"].sel(species=species).values @ a_matrix.T - component_coords = {"rate": ("component", rates), "lifetime": ("component", lifetimes)} + component_coords = { + "component": np.arange(rates.size), + "rate": ("component", rates), + "lifetime": ("component", lifetimes), + } das_coords = component_coords.copy() das_coords[global_dimension] = dataset.coords[global_dimension] das_name = f"decay_associated_{name}" diff --git 
a/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py index ce559a31d..4b5659943 100644 --- a/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py +++ b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import pytest import xarray as xr @@ -7,6 +9,7 @@ from glotaran.analysis.util import calculate_matrix from glotaran.builtin.megacomplexes.decay.test.test_decay_megacomplex import DecayModel from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex +from glotaran.model import Megacomplex from glotaran.model import Model from glotaran.parameter import ParameterGroup from glotaran.project import Scheme @@ -14,12 +17,22 @@ class SpectralModel(Model): @classmethod - def from_dict(cls, model_dict): + def from_dict( + cls, + model_dict, + *, + megacomplex_types: dict[str, type[Megacomplex]] | None = None, + default_megacomplex_type: str | None = None, + ): + defaults: dict[str, type[Megacomplex]] = { + "spectral": SpectralMegacomplex, + } + if megacomplex_types is not None: + defaults.update(megacomplex_types) return super().from_dict( model_dict, - megacomplex_types={ - "spectral": SpectralMegacomplex, - }, + megacomplex_types=defaults, + default_megacomplex_type=default_megacomplex_type, ) diff --git a/glotaran/model/model.py b/glotaran/model/model.py index 60ec800b0..13c7cf24c 100644 --- a/glotaran/model/model.py +++ b/glotaran/model/model.py @@ -56,7 +56,7 @@ def __init__( @classmethod def from_dict( cls, - model_dict_ref: dict, + model_dict: dict, *, megacomplex_types: dict[str, type[Megacomplex]], default_megacomplex_type: str | None = None, @@ -73,10 +73,10 @@ def from_dict( megacomplex_types=megacomplex_types, default_megacomplex_type=default_megacomplex_type ) - model_dict = copy.deepcopy(model_dict_ref) + model_dict_local = copy.deepcopy(model_dict) # TODO: maybe redundant? 
# iterate over items - for name, items in list(model_dict.items()): + for name, items in list(model_dict_local.items()): if name not in model._model_items: warn(f"Unknown model item type '{name}'.") diff --git a/glotaran/test/test_spectral_penalties.py b/glotaran/test/test_spectral_penalties.py index 5b055aedc..a9a870555 100644 --- a/glotaran/test/test_spectral_penalties.py +++ b/glotaran/test/test_spectral_penalties.py @@ -1,6 +1,5 @@ -# To add a new cell, type '# %%' -# To add a new markdown cell, type '# %% [markdown]' -# %% +from __future__ import annotations + import importlib from collections import namedtuple from copy import deepcopy @@ -12,6 +11,7 @@ from glotaran.builtin.megacomplexes.decay import DecayMegacomplex from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex from glotaran.io import prepare_time_trace_dataset +from glotaran.model import Megacomplex from glotaran.model import Model from glotaran.parameter import ParameterGroup from glotaran.project import Scheme @@ -28,13 +28,23 @@ class SpectralDecayModel(Model): @classmethod - def from_dict(cls, model_dict): + def from_dict( + cls, + model_dict, + *, + megacomplex_types: dict[str, type[Megacomplex]] | None = None, + default_megacomplex_type: str | None = None, + ): + defaults: dict[str, type[Megacomplex]] = { + "decay": DecayMegacomplex, + "spectral": SpectralMegacomplex, + } + if megacomplex_types is not None: + defaults.update(megacomplex_types) return super().from_dict( model_dict, - megacomplex_types={ - "decay": DecayMegacomplex, - "spectral": SpectralMegacomplex, - }, + megacomplex_types=defaults, + default_megacomplex_type=default_megacomplex_type, ) diff --git a/setup.cfg b/setup.cfg index 9b3a44fbc..2d7b1e90b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -61,6 +61,7 @@ glotaran.plugins.data_io = glotaran.plugins.megacomplexes = baseline = glotaran.builtin.megacomplexes.baseline coherent_artifact = glotaran.builtin.megacomplexes.coherent_artifact + damped_oscillation = glotaran.builtin.megacomplexes.damped_oscillation decay = glotaran.builtin.megacomplexes.decay spectral = glotaran.builtin.megacomplexes.spectral glotaran.plugins.project_io = From a464ec32271b22450db1ffb51bef3a6e6fe5a252 Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Fri, 20 Aug 2021 18:16:26 +0200 Subject: [PATCH 12/29] =?UTF-8?q?=F0=9F=94=A7=20Fix=20interrogate=20usage?= =?UTF-8?q?=20(#781)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🩹 🔧 Only run interrogate on glotaran folder * 👌 🚧 Raised interrogate threshold for 52% to 55% (current value 56.1%) --- .pre-commit-config.yaml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d22f3a349..997e06f46 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -104,7 +104,7 @@ repos: rev: 1.4.0 hooks: - id: interrogate - args: [-vv, --config=pyproject.toml] + args: [-vv, --config=pyproject.toml, glotaran] pass_filenames: false - repo: https://github.com/asottile/yesqa diff --git a/pyproject.toml b/pyproject.toml index 365b410df..678b78df6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ remove_redundant_aliases = true [tool.interrogate] exclude = ["setup.py", "docs", "*test/*", "benchmark/*"] ignore-init-module = true -fail-under = 52 +fail-under = 55 [tool.nbqa.addopts] flake8 = [ From b960a980bbfb5594714aebda6e73ff9227c28dc1 Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Sun, 22 Aug 2021 
10:07:26 +0200 Subject: [PATCH 13/29] =?UTF-8?q?=F0=9F=93=9A=20Add=20docs=20for=20the=20C?= =?UTF-8?q?LI=20(#784)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/requirements.txt | 1 + docs/source/conf.py | 1 + docs/source/index.rst | 1 + docs/source/user_documentation/cli.rst | 6 ++++++ setup.cfg | 2 +- 5 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 docs/source/user_documentation/cli.rst diff --git a/docs/requirements.txt b/docs/requirements.txt index 90a9fb6d8..6f855720d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,6 @@ # documentation dependencies Sphinx>=3.2.0 +sphinx-click>=3.0.1 sphinx-rtd-theme>=0.5.1 sphinx-copybutton>=0.3.0 myst-parser>=0.12.0 diff --git a/docs/source/conf.py b/docs/source/conf.py index 20bfbc691..f0e5173be 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -49,6 +49,7 @@ "sphinx.ext.imgmath", "sphinx.ext.viewcode", "sphinx.ext.napoleon", + "sphinx_click", "nbsphinx", "sphinx_last_updated_by_git", "myst_parser", diff --git a/docs/source/index.rst b/docs/source/index.rst index abc76d155..5946daa15 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -27,6 +27,7 @@ Welcome to pyglotaran's documentation! user_documentation/parameter user_documentation/optimizing user_documentation/using_plugins + user_documentation/cli .. toctree:: :maxdepth: 1 diff --git a/docs/source/user_documentation/cli.rst b/docs/source/user_documentation/cli.rst new file mode 100644 index 000000000..ec39f2f67 --- /dev/null +++ b/docs/source/user_documentation/cli.rst @@ -0,0 +1,6 @@ +Command-line Interface +====================== + +.. click:: glotaran.cli:main + :prog: glotaran + :nested: full diff --git a/setup.cfg b/setup.cfg index 2d7b1e90b..977820588 100644 --- a/setup.cfg +++ b/setup.cfg @@ -73,7 +73,7 @@ glotaran.plugins.project_io = test = pytest [rstcheck] -ignore_directives = autoattribute,autoclass,autoexception,autofunction,automethod,automodule,highlight +ignore_directives = autoattribute,autoclass,autoexception,autofunction,automethod,automodule,highlight,click ignore_messages = xarraydoc [darglint] From 5e9eb2268b7027f20023271417f02d2f823805c0 Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Sun, 22 Aug 2021 11:44:19 +0200 Subject: [PATCH 14/29] =?UTF-8?q?=F0=9F=9A=87=20Speedup=20PR=20benchmark?= =?UTF-8?q?=20(#785)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🚇 Only run base compare commit and last HEAD commit on PR benchmark * 👌 Update comment w/o deleting the last one (comment should have an 'edit' history) --- .github/workflows/pr_benchmark.yml | 3 ++- .github/workflows/pr_benchmark_reaction.yml | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr_benchmark.yml b/.github/workflows/pr_benchmark.yml index f75019c17..e1b60598b 100644 --- a/.github/workflows/pr_benchmark.yml +++ b/.github/workflows/pr_benchmark.yml @@ -48,7 +48,8 @@ jobs: git remote add upstream https://github.com/glotaran/pyglotaran git fetch upstream pushd benchmark - asv run v0.4.0^..HEAD --machine gh_action + asv run v0.4.0^..v0.4.0 --machine gh_action + asv run HEAD^..HEAD --machine gh_action asv publish - name: Checkout benchmark result repo diff --git a/.github/workflows/pr_benchmark_reaction.yml b/.github/workflows/pr_benchmark_reaction.yml index 9f80dd229..eb51ba0f8 100644 --- a/.github/workflows/pr_benchmark_reaction.yml +++ b/.github/workflows/pr_benchmark_reaction.yml @@ -72,7 
+72,6 @@ jobs: number: ${{ steps.bench_diff.outputs.pr_nr }} id: benchmark-comment message: ${{ steps.bench_diff.outputs.comment }} - recreate: true - name: Commit PR results env: From dd0ce806c8615600b0ec44e5ad71f7e9878a8c1e Mon Sep 17 00:00:00 2001 From: Joris Snellenburg Date: Wed, 25 Aug 2021 03:36:32 +0200 Subject: [PATCH 15/29] =?UTF-8?q?=F0=9F=A9=B9=20Update=20calls=20to=20clp?= =?UTF-8?q?=5Farea=5Fpenalties=20where=20deprecated=20func=20is=20called?= =?UTF-8?q?=20(#790)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The deprecated equal_area_penalties would be called in lieu of the new clp_area_penalties function The pyglotaran-examples have in their model still equal_area_penalties which is correctly swapped with clp_area_penalties but them model.equal_area_penalties was still called in some places instead of model.clp_area_penalties --- glotaran/model/clp_penalties.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/glotaran/model/clp_penalties.py b/glotaran/model/clp_penalties.py index c1c817fd4..2918d9b41 100644 --- a/glotaran/model/clp_penalties.py +++ b/glotaran/model/clp_penalties.py @@ -60,7 +60,7 @@ def applies(interval): def has_spectral_penalties(model: Model) -> bool: - return len(model.equal_area_penalties) != 0 + return len(model.clp_area_penalties) != 0 def apply_spectral_penalties( @@ -74,7 +74,7 @@ def apply_spectral_penalties( ) -> np.ndarray: penalties = [] - for penalty in model.equal_area_penalties: + for penalty in model.clp_area_penalties: penalty = penalty.fill(model, parameters) source_area = _get_area( From a6a5a510cc2e95b8d81b32413824a59093ad5522 Mon Sep 17 00:00:00 2001 From: Joris Snellenburg Date: Wed, 25 Aug 2021 04:02:05 +0200 Subject: [PATCH 16/29] Improve ordering in k_matrix involved_compartments function (#788) Avoiding the use of set to improve ordering Refactor at least guarantees better than random ordering using list(set(x)) TODO: further improve ordering to use initial_concentrations ordering --- glotaran/builtin/megacomplexes/decay/k_matrix.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/glotaran/builtin/megacomplexes/decay/k_matrix.py b/glotaran/builtin/megacomplexes/decay/k_matrix.py index 71a4cda27..0188a3637 100644 --- a/glotaran/builtin/megacomplexes/decay/k_matrix.py +++ b/glotaran/builtin/megacomplexes/decay/k_matrix.py @@ -39,12 +39,16 @@ def empty(cls, label: str, compartments: list[str]) -> KMatrix: def involved_compartments(self) -> list[str]: """A list of all compartments in the Matrix.""" + # TODO: find a better way that preserves ordering as defined in initial_concentrations compartments = [] for index in self.matrix: - compartments.append(index[0]) - compartments.append(index[1]) + if index[0] not in compartments: + compartments.append(index[0]) + if index[1] not in compartments: + compartments.append(index[1]) - compartments = list(set(compartments)) + # Don't use set, it randomly reorders the compartments. 
+ # compartments = list(set(compartments)) return compartments def combine(self, k_matrix: KMatrix) -> KMatrix: From b28179970cd7ab0d9b7e276c3d25ff7dae69e7cd Mon Sep 17 00:00:00 2001 From: Joris Snellenburg Date: Fri, 27 Aug 2021 01:35:39 +0200 Subject: [PATCH 17/29] =?UTF-8?q?=F0=9F=A9=B9=20Fix=20and=20re-enable=20IR?= =?UTF-8?q?F=20Dispersion=20Test=20(#786)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix and re-enable IRF Dispersion Test Improve the tests in test_spectral_irf Added NoDispersion test Re-enable IrfDispersion tests on the CI (de-Volkswagen the tests) Improve assignment of center_dispersion to result dataset * Rename center_dispersion to irf_center_location in result The name irf_center_location in the results describes more clearly what the variable holds. Deprecated irf model item center_dispersion for center_dispersion_coefficients Deprecated irf model item width_dispersion for width_dispersion_coefficients --- .github/workflows/CI_CD_actions.yml | 2 +- .../builtin/io/yml/test/test_model_parser.py | 4 +- .../builtin/io/yml/test/test_model_spec.yml | 35 ++-- glotaran/builtin/megacomplexes/decay/irf.py | 28 +-- .../decay/test/test_spectral_irf.py | 163 ++++++++++++++---- glotaran/builtin/megacomplexes/decay/util.py | 13 +- .../deprecation/modules/builtin_io_yml.py | 21 +++ .../modules/test/test_builtin_io_yml.py | 28 +++ 8 files changed, 223 insertions(+), 71 deletions(-) diff --git a/.github/workflows/CI_CD_actions.yml b/.github/workflows/CI_CD_actions.yml index f7b17499a..f8123d88c 100644 --- a/.github/workflows/CI_CD_actions.yml +++ b/.github/workflows/CI_CD_actions.yml @@ -109,7 +109,7 @@ jobs: run: pip freeze - name: Run tests run: | - pytest --cov=./ --cov-report term --cov-report xml --cov-config pyproject.toml -k 'not IrfDispersion' glotaran + pytest --cov=./ --cov-report term --cov-report xml --cov-config pyproject.toml glotaran - name: Codecov Upload uses: codecov/codecov-action@v2 diff --git a/glotaran/builtin/io/yml/test/test_model_parser.py b/glotaran/builtin/io/yml/test/test_model_parser.py index 8fa942b55..be4ac0ce3 100644 --- a/glotaran/builtin/io/yml/test/test_model_parser.py +++ b/glotaran/builtin/io/yml/test/test_model_parser.py @@ -126,9 +126,9 @@ def test_irf(model): assert irf.width == want want = [3] if i == 1 else [5, 6] if i == 2: - assert irf.center_dispersion == want + assert irf.center_dispersion_coefficients == want want = [7, 8] - assert irf.width_dispersion == want + assert irf.width_dispersion_coefficients == want want = [9] assert irf.scale == want assert irf.normalize == (i == 1) diff --git a/glotaran/builtin/io/yml/test/test_model_spec.yml b/glotaran/builtin/io/yml/test/test_model_spec.yml index cf5dfce21..2c485fa82 100644 --- a/glotaran/builtin/io/yml/test/test_model_spec.yml +++ b/glotaran/builtin/io/yml/test/test_model_spec.yml @@ -1,6 +1,5 @@ default-megacomplex: decay - dataset: dataset1: megacomplex: [cmplx1] @@ -23,36 +22,36 @@ irf: irf2: type: spectral-gaussian center: [1, 2] - width: [3,4] + width: [3, 4] scale: [9] normalize: false backsweep: true backsweep_period: 55 dispersion_center: 55 - center_dispersion: [5,6] - width_dispersion: [7,8] + center_dispersion_coefficients: [5, 6] + width_dispersion_coefficients: [7, 8] model_dispersion_with_wavenumber: true initial_concentration: inputD1: - compartments: [s1,s2,s3] - parameters: [1,2,3] + compartments: [s1, s2, s3] + parameters: [1, 2, 3] inputD2: - compartments: [s1,s2,s3] - parameters: [1,2,3] + compartments: [s1, s2, s3] + 
parameters: [1, 2, 3] # Convention matrix notation column = source, row = target compartment # (2,1) means from 1 to 2 k_matrix: km1: matrix: - (s1, s1): '1' - (s2, s1): '2' - (s1, s2): '3' - (s3, s1): '4' - (s1, s3): '5' - (s4, s1): '6' - (s1, s4): '7' + (s1, s1): "1" + (s2, s1): "2" + (s1, s2): "3" + (s3, s1): "4" + (s1, s3): "5" + (s4, s1): "6" + (s1, s4): "7" shape: shape1: @@ -63,9 +62,9 @@ shape: megacomplex: cmplx1: - k_matrix: [km1] # A megacomplex has one or more k-matrices + k_matrix: [km1] # A megacomplex has one or more k-matrices cmplx2: - k_matrix: [km2] + k_matrix: [km2] cmplx3: type: "spectral" shape: @@ -97,7 +96,7 @@ relations: - source: s1 target: s2 parameter: 8 - interval: [[1,100], [2,200]] + interval: [[1, 100], [2, 200]] weights: - datasets: [d1, d2] diff --git a/glotaran/builtin/megacomplexes/decay/irf.py b/glotaran/builtin/megacomplexes/decay/irf.py index 7ddbc17f6..14d7d2d78 100644 --- a/glotaran/builtin/megacomplexes/decay/irf.py +++ b/glotaran/builtin/megacomplexes/decay/irf.py @@ -47,10 +47,10 @@ class IrfMultiGaussian: one or more center of the irf as parameter indices width: one or more widths of the gaussian as parameter index - center_dispersion: + center_dispersion_coefficients: polynomial coefficients for the dispersion of the center as list of parameter indices. None for no dispersion. - width_dispersion: + width_dispersion_coefficients: polynomial coefficients for the dispersion of the width as parameter indices. None for no dispersion. @@ -124,8 +124,8 @@ class IrfGaussian(IrfMultiGaussian): @model_item( properties={ "dispersion_center": {"type": Parameter, "allow_none": True}, - "center_dispersion": {"type": List[Parameter], "default": []}, - "width_dispersion": {"type": List[Parameter], "default": []}, + "center_dispersion_coefficients": {"type": List[Parameter], "default": []}, + "width_dispersion_coefficients": {"type": List[Parameter], "default": []}, "model_dispersion_with_wavenumber": {"type": bool, "default": False}, }, has_type=True, @@ -149,12 +149,12 @@ class IrfSpectralMultiGaussian(IrfMultiGaussian): one or more center of the irf as parameter indices width: one or more widths of the gaussian as parameter index - center_dispersion: - polynomial coefficients for the dispersion of the - center as list of parameter indices. None for no dispersion. - width_dispersion: - polynomial coefficients for the dispersion of the - width as parameter indices. None for no dispersion. + center_dispersion_coefficients: + list of parameters with polynomial coefficients describing + the dispersion of the irf center location. None for no dispersion. + width_dispersion_coefficients: + list of parameters with polynomial coefficients describing + the dispersion of the width of the irf. None for no dispersion. 
""" @@ -173,16 +173,16 @@ def parameter(self, global_index: int, global_axis: np.ndarray): else (index - self.dispersion_center) / 100 ) - if len(self.center_dispersion) != 0: + if len(self.center_dispersion_coefficients) != 0: if self.dispersion_center is None: raise ModelError(f"No dispersion center defined for irf '{self.label}'") - for i, disp in enumerate(self.center_dispersion): + for i, disp in enumerate(self.center_dispersion_coefficients): centers += disp * np.power(dist, i + 1) - if len(self.width_dispersion) != 0: + if len(self.width_dispersion_coefficients) != 0: if self.dispersion_center is None: raise ModelError(f"No dispersion center defined for irf '{self.label}'") - for i, disp in enumerate(self.width_dispersion): + for i, disp in enumerate(self.width_dispersion_coefficients): widths = widths + disp * np.power(dist, i + 1) return centers, widths, scale, shift, backsweep, backsweep_period diff --git a/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py b/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py index 4256cbfa2..d75d636c2 100644 --- a/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py @@ -1,4 +1,6 @@ +import warnings from copy import deepcopy +from textwrap import dedent import numpy as np import pytest @@ -35,6 +37,14 @@ sh1: type: one """ +MODEL_NO_IRF_DISPERSION = f"""\ +{MODEL_BASE} +irf: + irf1: + type: spectral-gaussian + center: irf.center + width: irf.width +""" MODEL_SIMPLE_IRF_DISPERSION = f"""\ {MODEL_BASE} irf: @@ -43,7 +53,7 @@ center: irf.center width: irf.width dispersion_center: irf.dispersion_center - center_dispersion: [irf.center_dispersion] + center_dispersion_coefficients: [irf.cdc1] """ MODEL_MULTI_IRF_DISPERSION = f"""\ {MODEL_BASE} @@ -53,8 +63,19 @@ center: [irf.center] width: [irf.width] dispersion_center: irf.dispersion_center - center_dispersion: [irf.center_dispersion1, irf.center_dispersion2] - width_dispersion: [irf.width_dispersion] + center_dispersion_coefficients: [irf.cdc1, irf.cdc2] + width_dispersion_coefficients: [irf.wdc1] +""" + +MODEL_MULTIPULSE_IRF_DISPERSION = f"""\ +{MODEL_BASE} +irf: + irf1: + type: spectral-multi-gaussian + center: [irf.center1, irf.center2] + width: [irf.width] + dispersion_center: irf.dispersion_center + center_dispersion_coefficients: [irf.cdc1, irf.cdc2, irf.cdc3] """ PARAMETERS_BASE = """\ @@ -64,62 +85,111 @@ - ['1', 0.5, {'non-negative': False}] """ +PARAMETERS_NO_IRF_DISPERSION = f"""\ +{PARAMETERS_BASE} +irf: + - ['center', 0.3] + - ['width', 0.1] +""" + PARAMETERS_SIMPLE_IRF_DISPERSION = f"""\ {PARAMETERS_BASE} irf: - ['center', 0.3] - ['width', 0.1] - ['dispersion_center', 400, {{'vary': False}}] - - ['center_dispersion', 0.5] + - ['cdc1', 0.5] """ +# What is this? 
PARAMETERS_MULTI_IRF_DISPERSION = f"""\ {PARAMETERS_BASE} irf: - ["center", 0.3] - ["width", 0.1] - ["dispersion_center", 400, {{"vary": False}}] - - ["center_dispersion1", 0.01] - - ["center_dispersion2", 0.001] - - ["width_dispersion", 0.025] + - ["cdc1", 0.1] + - ["cdc2", 0.01] + - ["wdc1", 0.025] """ +PARAMETERS_MULTIPULSE_IRF_DISPERSION = f"""\ +{PARAMETERS_BASE} +irf: + - ["center1", 0.3] + - ["center2", 0.4] + - ['width', 0.1] + - ['dispersion_center', 400, {{'vary': False}}] + - ["cdc1", 0.5] + - ["cdc2", 0.1] + - ["cdc3", -0.01] +""" + + +def _time_axis(): + time_p1 = np.linspace(-1, 1, 20, endpoint=False) + time_p2 = np.linspace(1, 2, 10, endpoint=False) + time_p3 = np.geomspace(2, 20, num=20) + return np.array(np.concatenate([time_p1, time_p2, time_p3])) + + +def _spectral_axis(): + return np.linspace(300, 500, 3) + + +def _calculate_irf_position( + index, center, dispersion_center=None, center_dispersion_coefficients=None +): + if center_dispersion_coefficients is None: + center_dispersion_coefficients = [] + if dispersion_center is not None: + distance = (index - dispersion_center) / 100 + if dispersion_center is not None: + for i, coefficient in enumerate(center_dispersion_coefficients): + center += coefficient * np.power(distance, i + 1) + return center + + +class NoIrfDispersion: + model = load_model(MODEL_NO_IRF_DISPERSION, format_name="yml_str") + parameters = load_parameters(PARAMETERS_NO_IRF_DISPERSION, format_name="yml_str") + axis = {"time": _time_axis(), "spectral": _spectral_axis()} + class SimpleIrfDispersion: model = load_model(MODEL_SIMPLE_IRF_DISPERSION, format_name="yml_str") parameters = load_parameters(PARAMETERS_SIMPLE_IRF_DISPERSION, format_name="yml_str") - time_p1 = np.linspace(-1, 2, 50, endpoint=False) - time_p2 = np.linspace(2, 5, 30, endpoint=False) - time_p3 = np.geomspace(5, 10, num=20) - time = np.array(np.concatenate([time_p1, time_p2, time_p3])) - spectral = np.arange(300, 500, 100) - axis = {"time": time, "spectral": spectral} + axis = {"time": _time_axis(), "spectral": _spectral_axis()} class MultiIrfDispersion: model = load_model(MODEL_MULTI_IRF_DISPERSION, format_name="yml_str") parameters = load_parameters(PARAMETERS_MULTI_IRF_DISPERSION, format_name="yml_str") - time = np.arange(-1, 5, 0.2) - spectral = np.arange(300, 500, 100) - axis = {"time": time, "spectral": spectral} + axis = {"time": _time_axis(), "spectral": _spectral_axis()} + + +class MultiCenterIrfDispersion: + model = load_model(MODEL_MULTIPULSE_IRF_DISPERSION, format_name="yml_str") + parameters = load_parameters(PARAMETERS_MULTIPULSE_IRF_DISPERSION, format_name="yml_str") + axis = {"time": _time_axis(), "spectral": _spectral_axis()} @pytest.mark.parametrize( "suite", [ + NoIrfDispersion, SimpleIrfDispersion, MultiIrfDispersion, + MultiCenterIrfDispersion, ], ) def test_spectral_irf(suite): model = suite.model - print(model.validate()) - assert model.valid() + assert model.valid(), model.validate() parameters = suite.parameters - print(model.validate(parameters)) - assert model.valid(parameters) + assert model.valid(parameters), model.validate(parameters) sim_model = deepcopy(model) sim_model.dataset["dataset1"].global_megacomplex = ["mc2"] @@ -136,27 +206,62 @@ def test_spectral_irf(suite): maximum_number_function_evaluations=20, ) result = optimize(scheme) - print(result.optimized_parameters) + # print(result.optimized_parameters) for label, param in result.optimized_parameters.all(): assert np.allclose(param.value, parameters.get(label).value, rtol=1e-1) resultdata = 
result.data["dataset1"] - print(resultdata) - + # print(resultdata) assert np.array_equal(dataset["time"], resultdata["time"]) assert np.array_equal(dataset["spectral"], resultdata["spectral"]) assert dataset.data.shape == resultdata.data.shape assert dataset.data.shape == resultdata.fitted_data.shape - assert np.allclose(dataset.data, resultdata.fitted_data, atol=1e-14) + # assert np.allclose(dataset.data, resultdata.fitted_data, atol=1e-14) + + fit_data_max_at_start = resultdata.fitted_data.isel(spectral=0).argmax(axis=0) + fit_data_max_at_end = resultdata.fitted_data.isel(spectral=-1).argmax(axis=0) + + if suite is NoIrfDispersion: + assert "center_dispersion_1" not in resultdata + assert fit_data_max_at_start == fit_data_max_at_end + else: + assert "center_dispersion_1" in resultdata + assert fit_data_max_at_start != fit_data_max_at_end + if abs(fit_data_max_at_start - fit_data_max_at_end) < 3: + warnings.warn( + dedent( + """ + Bad test, one of the following could be the case: + - dispersion too small + - spectral window to small + - time resolution (around the maximum of the IRF) too low" + """ + ) + ) - irf_max_at_start = resultdata.fitted_data.isel(spectral=0).argmax(axis=0) - irf_max_at_end = resultdata.fitted_data.isel(spectral=-1).argmax(axis=0) - print(f" irf_max_at_start: {irf_max_at_start}\n irf_max_at_end: {irf_max_at_end}") - # These should not be equal due to dispersion: - assert irf_max_at_start != irf_max_at_end + for x in suite.axis["spectral"]: + # calculated irf location + model_irf_center = suite.model.irf["irf1"].center + model_dispersion_center = suite.model.irf["irf1"].dispersion_center + model_center_dispersion_coefficients = suite.model.irf[ + "irf1" + ].center_dispersion_coefficients + calc_irf_location_at_x = _calculate_irf_position( + x, model_irf_center, model_dispersion_center, model_center_dispersion_coefficients + ) + # fitted irf location + fitted_irf_loc_at_x = resultdata["irf_center_location"].sel(spectral=x) + assert np.allclose(calc_irf_location_at_x, fitted_irf_loc_at_x) assert "species_associated_spectra" in resultdata assert "decay_associated_spectra" in resultdata assert "irf_center" in resultdata + + +if __name__ == "__main__": + test_spectral_irf(NoIrfDispersion) + test_spectral_irf(SimpleIrfDispersion) + test_spectral_irf(MultiIrfDispersion) + test_spectral_irf(MultiCenterIrfDispersion) diff --git a/glotaran/builtin/megacomplexes/decay/util.py b/glotaran/builtin/megacomplexes/decay/util.py index c51a3ce67..869d9d39c 100644 --- a/glotaran/builtin/megacomplexes/decay/util.py +++ b/glotaran/builtin/megacomplexes/decay/util.py @@ -228,10 +228,9 @@ def retrieve_irf(dataset_model: DatasetModel, dataset: xr.Dataset, global_dimens dataset["irf_center"] = ("irf_nr", center) if len(center) > 1 else center[0] dataset["irf_width"] = ("irf_nr", width) if len(width) > 1 else width[0] if isinstance(irf, IrfSpectralMultiGaussian) and irf.dispersion_center: - for i, dispersion in enumerate( - irf.calculate_dispersion(dataset.coords["spectral"].values) - ): - dataset[f"center_dispersion_{i+1}"] = ( - global_dimension, - dispersion, - ) + dataset["irf_center_location"] = ( + ("irf_nr", global_dimension), + irf.calculate_dispersion(dataset.coords["spectral"].values), + ) + # center_dispersion_1 for backwards compatibility (0.3-0.4.1) + dataset["center_dispersion_1"] = dataset["irf_center_location"].sel(irf_nr=0) diff --git a/glotaran/deprecation/modules/builtin_io_yml.py b/glotaran/deprecation/modules/builtin_io_yml.py index 635a5a9a9..0a3aa99dc 100644 --- 
a/glotaran/deprecation/modules/builtin_io_yml.py +++ b/glotaran/deprecation/modules/builtin_io_yml.py @@ -85,3 +85,24 @@ def model_spec_deprecations(spec: MutableMapping[Any, Any]) -> None: swap_keys=("equal_area_penalties", "clp_area_penalties"), stacklevel=load_model_stack_level, ) + + if "irf" in spec: + for _, irf in spec["irf"].items(): + deprecate_dict_entry( + dict_to_check=irf, + deprecated_usage="center_dispersion", + new_usage="center_dispersion_coefficients", + to_be_removed_in_version="0.7.0", + swap_keys=("center_dispersion", "center_dispersion_coefficients"), + stacklevel=load_model_stack_level, + ) + + for _, irf in spec["irf"].items(): + deprecate_dict_entry( + dict_to_check=irf, + deprecated_usage="width_dispersion", + new_usage="width_dispersion_coefficients", + to_be_removed_in_version="0.7.0", + swap_keys=("width_dispersion", "width_dispersion_coefficients"), + stacklevel=load_model_stack_level, + ) diff --git a/glotaran/deprecation/modules/test/test_builtin_io_yml.py b/glotaran/deprecation/modules/test/test_builtin_io_yml.py index 26202b947..8cf03f510 100644 --- a/glotaran/deprecation/modules/test/test_builtin_io_yml.py +++ b/glotaran/deprecation/modules/test/test_builtin_io_yml.py @@ -55,6 +55,32 @@ "clp_area_penalties", [{"type": "equal_area"}], ), + ( + dedent( + """ + irf: + irf1: + center_dispersion: [cdc1] + + """ + ), + 1, + "irf", + {"irf1": {"center_dispersion_coefficients": ["cdc1"]}}, + ), + ( + dedent( + """ + irf: + irf1: + "width_dispersion": [wdc1] + + """ + ), + 1, + "irf", + {"irf1": {"width_dispersion_coefficients": ["wdc1"]}}, + ), ), ids=( "type: kinetic-spectrum", @@ -62,6 +88,8 @@ "spectral_relations", "spectral_constraints", "equal_area_penalties", + "center_dispersion", + "width_dispersion", ), ) def test_model_spec_deprecations( From 6689a9561561b3290f7da2855d754ac972bd7f13 Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Sun, 29 Aug 2021 16:17:02 +0200 Subject: [PATCH 18/29] =?UTF-8?q?=F0=9F=A9=B9=20Fixed=20wrong=20value=20fo?= =?UTF-8?q?r=20model=20spec=20deprecation=20and=20spectral=20model=20(#795?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The spectral model (type) freshly added in v0.4 was then called spectral-model. 
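For illustration, a minimal sketch (not part of this patch) of the rewrite the corrected rule performs on a loaded model spec. The helper below is hypothetical and only mirrors the replace_rules pair ({"type": "spectral-model"} -> {"default-megacomplex": "spectral"}); the real deprecate_dict_entry call also emits a deprecation warning, which is omitted here.

def sketch_spectral_model_deprecation(spec: dict) -> dict:
    # Hypothetical stand-in for the replace_rules step of deprecate_dict_entry:
    # the legacy key/value pair is swapped for the new default-megacomplex entry.
    if spec.get("type") == "spectral-model":
        spec = {key: value for key, value in spec.items() if key != "type"}
        spec["default-megacomplex"] = "spectral"
    return spec

legacy_spec = {"type": "spectral-model", "megacomplex": {"m1": {"type": "spectral"}}}
assert sketch_spectral_model_deprecation(legacy_spec) == {
    "default-megacomplex": "spectral",
    "megacomplex": {"m1": {"type": "spectral"}},
}
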
--- glotaran/deprecation/modules/builtin_io_yml.py | 4 ++-- glotaran/deprecation/modules/test/test_builtin_io_yml.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/glotaran/deprecation/modules/builtin_io_yml.py b/glotaran/deprecation/modules/builtin_io_yml.py index 0a3aa99dc..0bc5c59df 100644 --- a/glotaran/deprecation/modules/builtin_io_yml.py +++ b/glotaran/deprecation/modules/builtin_io_yml.py @@ -30,10 +30,10 @@ def model_spec_deprecations(spec: MutableMapping[Any, Any]) -> None: deprecate_dict_entry( dict_to_check=spec, - deprecated_usage="type: spectrum", + deprecated_usage="type: spectral-model", new_usage="default-megacomplex: spectral", to_be_removed_in_version="0.7.0", - replace_rules=({"type": "spectrum"}, {"default-megacomplex": "spectral"}), + replace_rules=({"type": "spectral-model"}, {"default-megacomplex": "spectral"}), stacklevel=load_model_stack_level, ) diff --git a/glotaran/deprecation/modules/test/test_builtin_io_yml.py b/glotaran/deprecation/modules/test/test_builtin_io_yml.py index 8cf03f510..873230366 100644 --- a/glotaran/deprecation/modules/test/test_builtin_io_yml.py +++ b/glotaran/deprecation/modules/test/test_builtin_io_yml.py @@ -19,7 +19,7 @@ "model_yml_str, expected_nr_of_warnings, expected_key, expected_value", ( ("type: kinetic-spectrum", 1, "default-megacomplex", "decay"), - ("type: spectrum", 1, "default-megacomplex", "spectral"), + ("type: spectral-model", 1, "default-megacomplex", "spectral"), ( dedent( """ From 9dc64affda014d31e05baa5e864528373980e258 Mon Sep 17 00:00:00 2001 From: Joris Snellenburg Date: Sun, 29 Aug 2021 21:12:15 +0200 Subject: [PATCH 19/29] Cleanup IRF Dispersion test PR #786 Address a trivial Sourcery concern to avoid further PRs Reset low relative tolerance back to default when comparing parameter values Added some helpful hints to comparison asserts in case of failure --- .../decay/test/test_spectral_irf.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py b/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py index d75d636c2..2bbed40d0 100644 --- a/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py @@ -144,7 +144,6 @@ def _calculate_irf_position( center_dispersion_coefficients = [] if dispersion_center is not None: distance = (index - dispersion_center) / 100 - if dispersion_center is not None: for i, coefficient in enumerate(center_dispersion_coefficients): center += coefficient * np.power(distance, i + 1) return center @@ -206,10 +205,14 @@ def test_spectral_irf(suite): maximum_number_function_evaluations=20, ) result = optimize(scheme) - # print(result.optimized_parameters) for label, param in result.optimized_parameters.all(): - assert np.allclose(param.value, parameters.get(label).value, rtol=1e-1) + assert np.allclose(param.value, parameters.get(label).value), dedent( + f""" + Error in {suite.__name__} comparing {param.full_label}, + - diff={param.value-parameters.get(label).value} + """ + ) resultdata = result.data["dataset1"] @@ -253,15 +256,13 @@ def test_spectral_irf(suite): ) # fitted irf location fitted_irf_loc_at_x = resultdata["irf_center_location"].sel(spectral=x) - assert np.allclose(calc_irf_location_at_x, fitted_irf_loc_at_x) + assert np.allclose(calc_irf_location_at_x, fitted_irf_loc_at_x.values), dedent( + f""" + Error in {suite.__name__} comparing irf_center_location, + - 
diff={calc_irf_location_at_x-fitted_irf_loc_at_x.values} + """ + ) assert "species_associated_spectra" in resultdata assert "decay_associated_spectra" in resultdata assert "irf_center" in resultdata - - -if __name__ == "__main__": - test_spectral_irf(NoIrfDispersion) - test_spectral_irf(SimpleIrfDispersion) - test_spectral_irf(MultiIrfDispersion) - test_spectral_irf(MultiCenterIrfDispersion) From 70a12053f8718f489fb4bc0cda240bd5ae0876ae Mon Sep 17 00:00:00 2001 From: Joris Snellenburg Date: Tue, 31 Aug 2021 20:23:05 +0200 Subject: [PATCH 20/29] Fix compartment ordering randomization due to use of set (#799) Get rid of all things set() This solves the issue whereby sometimes in examination with multiple datasets the compartment ordering would be randomized. --- glotaran/analysis/problem_grouped.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/glotaran/analysis/problem_grouped.py b/glotaran/analysis/problem_grouped.py index 5d0b0041e..7711f50ab 100644 --- a/glotaran/analysis/problem_grouped.py +++ b/glotaran/analysis/problem_grouped.py @@ -263,9 +263,12 @@ def calculate_index_independent_matrices( self._reduced_matrices[group_label] = combine_matrices( [self._reduced_matrices[label] for label in group] ) - self._group_clp_labels[group_label] = list( - set(itertools.chain(*(self._matrices[label].clp_labels for label in group))) - ) + group_clp_labels = [] + for label in group: + for clp_label in self._matrices[label].clp_labels: + if clp_label not in group_clp_labels: + group_clp_labels.append(clp_label) + self._group_clp_labels[group_label] = group_clp_labels return self._matrices, self._reduced_matrices From fe7d75c70e184186a779ace7187c451789009ad6 Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Fri, 3 Sep 2021 22:31:33 +0200 Subject: [PATCH 21/29] =?UTF-8?q?=F0=9F=A9=B9=20False=20positive=20model?= =?UTF-8?q?=20validation=20fail=20when=20combining=20multiple=20default=20?= =?UTF-8?q?megacomplexes=20(#797)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🩹 Fixed missing 'f' in front of f-string which rendered the error useless * 🩹 Deactivate instance check if megacomplex type is None Compatibility with legacy (0.4.0) model specs and combining multiple megacomplexes * 🩹 Fixed usage of 'glotaran_unique' in 'ensure_unique_megacomplexes' * 👌 Improved typing * 🧪 Added unittests for 'ensure_unique_megacomplexes' * 🩹 Fixed missing megacomplex definition not being reported as error * 🩹 Fixed megacomplex type being None * ♻️ Refactored 'ensure_unique_megacomplexes' using Counter and improved typing. Since 'DatasetModel' is never instantiated from the class definition, but 'create_dataset_model_type' creates a new class in a closure each time a new instance is needed, values across instances won't be overwritten. 
* ♻️ Moved typing information to stub file After a personal request of @joernweissenborn I removed the class attributes again and moved them to a pyi file * 🧹 Removed warning about missing megacomplex definition Since it is already reported as 'Missing Model Item' --- glotaran/model/dataset_model.py | 78 ++++++++++++-------- glotaran/model/dataset_model.pyi | 47 ++++++++++++ glotaran/model/megacomplex.py | 1 + glotaran/model/test/test_dataset_model.py | 88 +++++++++++++++++++++++ 4 files changed, 184 insertions(+), 30 deletions(-) create mode 100644 glotaran/model/dataset_model.pyi create mode 100644 glotaran/model/test/test_dataset_model.py diff --git a/glotaran/model/dataset_model.py b/glotaran/model/dataset_model.py index d86a99ad3..2290ff3c9 100644 --- a/glotaran/model/dataset_model.py +++ b/glotaran/model/dataset_model.py @@ -1,22 +1,26 @@ """The DatasetModel class.""" from __future__ import annotations +from collections import Counter from typing import TYPE_CHECKING -from typing import Generator import numpy as np import xarray as xr from glotaran.model.item import model_item from glotaran.model.item import model_item_validator -from glotaran.parameter import Parameter if TYPE_CHECKING: + from typing import Any + from typing import Generator + from typing import Hashable + from glotaran.model.megacomplex import Megacomplex from glotaran.model.model import Model + from glotaran.parameter import Parameter -def create_dataset_model_type(properties: dict[str, any]) -> type: +def create_dataset_model_type(properties: dict[str, Any]) -> type[DatasetModel]: @model_item(properties=properties) class ModelDatasetModel(DatasetModel): pass @@ -33,13 +37,17 @@ class DatasetModel: parameter. """ - def iterate_megacomplexes(self) -> Generator[tuple[Parameter | int, Megacomplex | str]]: + def iterate_megacomplexes( + self, + ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: """Iterates of der dataset model's megacomplexes.""" for i, megacomplex in enumerate(self.megacomplex): scale = self.megacomplex_scale[i] if self.megacomplex_scale is not None else None yield scale, megacomplex - def iterate_global_megacomplexes(self) -> Generator[tuple[Parameter | int, Megacomplex | str]]: + def iterate_global_megacomplexes( + self, + ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: """Iterates of der dataset model's global megacomplexes.""" for i, megacomplex in enumerate(self.global_megacomplex): scale = ( @@ -63,7 +71,7 @@ def get_model_dimension(self) -> str: ) return self._model_dimension - def finalize_data(self, dataset: xr.Dataset): + def finalize_data(self, dataset: xr.Dataset) -> None: is_full_model = self.has_global_model() for megacomplex in self.megacomplex: megacomplex.finalize_data(self, dataset, is_full_model=is_full_model) @@ -73,7 +81,7 @@ def finalize_data(self, dataset: xr.Dataset): self, dataset, is_full_model=is_full_model, as_global=True ) - def overwrite_model_dimension(self, model_dimension: str): + def overwrite_model_dimension(self, model_dimension: str) -> None: """Overwrites the dataset model's model dimension.""" self._model_dimension = model_dimension @@ -104,11 +112,11 @@ def get_global_dimension(self) -> str: ) return self._global_dimension - def overwrite_global_dimension(self, global_dimension: str): + def overwrite_global_dimension(self, global_dimension: str) -> None: """Overwrites the dataset model's global dimension.""" self._global_dimension = global_dimension - def swap_dimensions(self): + def swap_dimensions(self) 
-> None: """Swaps the dataset model's global and model dimension.""" global_dimension = self.get_model_dimension() model_dimension = self.get_global_dimension() @@ -117,9 +125,7 @@ def swap_dimensions(self): def set_data(self, dataset: xr.Dataset) -> DatasetModel: """Sets the dataset model's data.""" - self._coords: dict[str, np.ndarray] = { - name: dim.values for name, dim in dataset.coords.items() - } + self._coords = {name: dim.values for name, dim in dataset.coords.items()} self._data: np.ndarray = dataset.data.values self._weight: np.ndarray | None = dataset.weight.values if "weight" in dataset else None if self._weight is not None: @@ -152,7 +158,7 @@ def set_coordinates(self, coords: dict[str, np.ndarray]): """Sets the dataset model's coordinates.""" self._coords = coords - def get_coordinates(self) -> np.ndarray: + def get_coordinates(self) -> dict[Hashable, np.ndarray]: """Gets the dataset model's coordinates.""" return self._coords @@ -166,20 +172,32 @@ def get_global_axis(self) -> np.ndarray: @model_item_validator(False) def ensure_unique_megacomplexes(self, model: Model) -> list[str]: - - megacomplexes = [model.megacomplex[m] for m in self.megacomplex if m in model.megacomplex] - types = {type(m) for m in megacomplexes} - problems = [] - - for megacomplex_type in types: - if not megacomplex_type.glotaran_unique: - continue - instances = [m for m in megacomplexes if isinstance(m, megacomplex_type)] - n = len(instances) - if n != 1: - problems.append( - f"Multiple instances of unique megacomplex type '{instances[0].type}' " - "in dataset {self.label}" - ) - - return problems + """Ensure that unique megacomplexes Are only used once per dataset. + + Parameters + ---------- + model : Model + Model object using this dataset model. + + Returns + ------- + list[str] + Error messages to be shown when the model gets validated. + """ + glotaran_unique_megacomplex_types = [] + + for megacomplex_name in self.megacomplex: + try: + megacomplex_instance = model.megacomplex[megacomplex_name] + if type(megacomplex_instance).glotaran_unique() is True: + type_name = megacomplex_instance.type or megacomplex_instance.name + glotaran_unique_megacomplex_types.append(type_name) + except KeyError: + pass + + return [ + f"Multiple instances of unique megacomplex type {type_name!r} " + f"in dataset {self.label!r}" + for type_name, count in Counter(glotaran_unique_megacomplex_types).most_common() + if count > 1 + ] diff --git a/glotaran/model/dataset_model.pyi b/glotaran/model/dataset_model.pyi new file mode 100644 index 000000000..c0a7b49f7 --- /dev/null +++ b/glotaran/model/dataset_model.pyi @@ -0,0 +1,47 @@ +from __future__ import annotations + +from typing import Any +from typing import Generator +from typing import Hashable + +import numpy as np +import xarray as xr + +from glotaran.model.megacomplex import Megacomplex +from glotaran.model.model import Model +from glotaran.parameter import Parameter + +def create_dataset_model_type(properties: dict[str, Any]) -> type[DatasetModel]: ... + +class DatasetModel: + + label: str + megacomplex: list[str] + megacomplex_scale: list[Parameter] | None + global_megacomplex: list[str] + global_megacomplex_scale: list[Parameter] | None + scale: Parameter | None + _coords: dict[Hashable, np.ndarray] + def iterate_megacomplexes( + self, + ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: ... + def iterate_global_megacomplexes( + self, + ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: ... 
+ def get_model_dimension(self) -> str: ... + def finalize_data(self, dataset: xr.Dataset) -> None: ... + def overwrite_model_dimension(self, model_dimension: str) -> None: ... + def get_global_dimension(self) -> str: ... + def overwrite_global_dimension(self, global_dimension: str) -> None: ... + def swap_dimensions(self) -> None: ... + def set_data(self, dataset: xr.Dataset) -> DatasetModel: ... + def get_data(self) -> np.ndarray: ... + def get_weight(self) -> np.ndarray | None: ... + def index_dependent(self) -> bool: ... + def overwrite_index_dependent(self, index_dependent: bool): ... + def has_global_model(self) -> bool: ... + def set_coordinates(self, coords: dict[str, np.ndarray]): ... + def get_coordinates(self) -> dict[Hashable, np.ndarray]: ... + def get_model_axis(self) -> np.ndarray: ... + def get_global_axis(self) -> np.ndarray: ... + def ensure_unique_megacomplexes(self, model: Model) -> list[str]: ... diff --git a/glotaran/model/megacomplex.py b/glotaran/model/megacomplex.py index 471dcf4ac..e8ae6afc1 100644 --- a/glotaran/model/megacomplex.py +++ b/glotaran/model/megacomplex.py @@ -70,6 +70,7 @@ def decorator(cls): megacomplex_type = model_item(properties=properties, has_type=True)(cls) if register_as is not None: + megacomplex_type.name = register_as register_megacomplex(register_as, megacomplex_type) return megacomplex_type diff --git a/glotaran/model/test/test_dataset_model.py b/glotaran/model/test/test_dataset_model.py new file mode 100644 index 000000000..ac89ab74d --- /dev/null +++ b/glotaran/model/test/test_dataset_model.py @@ -0,0 +1,88 @@ +"""Tests for glotaran.model.dataset_model.DatasetModel""" +from __future__ import annotations + +import pytest + +from glotaran.builtin.megacomplexes.baseline import BaselineMegacomplex +from glotaran.builtin.megacomplexes.coherent_artifact import CoherentArtifactMegacomplex +from glotaran.builtin.megacomplexes.damped_oscillation import DampedOscillationMegacomplex +from glotaran.builtin.megacomplexes.decay import DecayMegacomplex +from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex +from glotaran.model.dataset_model import create_dataset_model_type +from glotaran.model.model import default_dataset_properties + + +class MockModel: + """Test Model only containing the megacomplex property. + + Multiple and different kinds of megacomplexes are defined + but only a subset will be used by the DatsetModel. 
+ """ + + def __init__(self) -> None: + self.megacomplex = { + # not unique + "d1": DecayMegacomplex(), + "d2": DecayMegacomplex(), + "d3": DecayMegacomplex(), + "s1": SpectralMegacomplex(), + "s2": SpectralMegacomplex(), + "s3": SpectralMegacomplex(), + "doa1": DampedOscillationMegacomplex(), + "doa2": DampedOscillationMegacomplex(), + # unique + "b1": BaselineMegacomplex(), + "b2": BaselineMegacomplex(), + "c1": CoherentArtifactMegacomplex(), + "c2": CoherentArtifactMegacomplex(), + } + + +@pytest.mark.parametrize( + "used_megacomplexes, expected_problems", + ( + ( + ["d1"], + [], + ), + ( + ["d1", "d2", "d3"], + [], + ), + ( + ["s1", "s2", "s3"], + [], + ), + ( + ["d1", "d2", "d3", "s1", "s2", "s3", "doa1", "doa2", "b1", "c1"], + [], + ), + ( + ["d1", "b1", "b2"], + ["Multiple instances of unique megacomplex type 'baseline' in dataset 'ds1'"], + ), + ( + ["d1", "c1", "c2"], + ["Multiple instances of unique megacomplex type 'coherent-artifact' in dataset 'ds1'"], + ), + ( + ["d1", "b1", "b2", "c1", "c2"], + [ + "Multiple instances of unique megacomplex type 'baseline' in dataset 'ds1'", + "Multiple instances of unique megacomplex type " + "'coherent-artifact' in dataset 'ds1'", + ], + ), + ), +) +def test_datasetmodel_ensure_unique_megacomplexes( + used_megacomplexes: list[str], expected_problems: list[str] +): + """Only report problems if multiple unique megacomplexes of the same type are used.""" + dataset_model = create_dataset_model_type({**default_dataset_properties})() + dataset_model.megacomplex = used_megacomplexes # type:ignore + dataset_model.label = "ds1" # type:ignore + problems = dataset_model.ensure_unique_megacomplexes(MockModel()) # type:ignore + + assert len(problems) == len(expected_problems) + assert problems == expected_problems From 2a6b3e623822d1fedc5d198764cbbd5b500ca6f6 Mon Sep 17 00:00:00 2001 From: Joris Snellenburg Date: Sat, 4 Sep 2021 02:56:12 +0200 Subject: [PATCH 22/29] Improvements to application of clp_penalties (equal area) (#801) * Calculate clp_penalties per datasets like in v0.4.1 Add reference to dataset_models to calculate_clp_penalties to calculate clp_penalties per dataset (axis) --- glotaran/analysis/problem_grouped.py | 7 ++- glotaran/analysis/problem_ungrouped.py | 7 ++- glotaran/analysis/util.py | 61 ++++++++++++++++++-------- glotaran/model/clp_penalties.py | 1 + 4 files changed, 56 insertions(+), 20 deletions(-) diff --git a/glotaran/analysis/problem_grouped.py b/glotaran/analysis/problem_grouped.py index 7711f50ab..e898970b2 100644 --- a/glotaran/analysis/problem_grouped.py +++ b/glotaran/analysis/problem_grouped.py @@ -293,7 +293,12 @@ def calculate_residual(self): self._weighted_residuals = list(map(lambda result: result[2], results)) self._residuals = list(map(lambda result: result[3], results)) self._additional_penalty = calculate_clp_penalties( - self.model, self.parameters, self._clp_labels, self._grouped_clps, self._full_axis + self.model, + self.parameters, + self._clp_labels, + self._grouped_clps, + self._full_axis, + self.dataset_models, ) return self._reduced_clps, self._clps, self._weighted_residuals, self._residuals diff --git a/glotaran/analysis/problem_ungrouped.py b/glotaran/analysis/problem_ungrouped.py index 10b03eff9..8b7fedb18 100644 --- a/glotaran/analysis/problem_ungrouped.py +++ b/glotaran/analysis/problem_ungrouped.py @@ -168,7 +168,12 @@ def _calculate_residual(self, label: str, dataset_model: DatasetModel): clp_labels = self._get_clp_labels(label) additional_penalty = calculate_clp_penalties( - self.model, 
self.parameters, clp_labels, self._clps[label], global_axis + self.model, + self.parameters, + clp_labels, + self._clps[label], + global_axis, + self.dataset_models, ) if additional_penalty.size != 0: self._additional_penalty.append(additional_penalty) diff --git a/glotaran/analysis/util.py b/glotaran/analysis/util.py index d54cb06b0..e25e79136 100644 --- a/glotaran/analysis/util.py +++ b/glotaran/analysis/util.py @@ -194,27 +194,48 @@ def calculate_clp_penalties( clp_labels: list[list[str]] | list[str], clps: list[np.ndarray], global_axis: np.ndarray, + dataset_models: dict[str, DatasetModel], ) -> np.ndarray: + # TODO: make a decision on how to handle clp_penalties per dataset + # 1. sum up contributions per dataset on each dataset_axis (v0.4.1) + # 2. sum up contributions on the global_axis (future?) + penalties = [] for penalty in model.clp_area_penalties: penalty = penalty.fill(model, parameters) - source_area = _get_area( - penalty.source, - clp_labels, - clps, - penalty.source_intervals, - global_axis, - ) - - target_area = _get_area( - penalty.target, - clp_labels, - clps, - penalty.target_intervals, - global_axis, - ) - + source_area = np.array([]) + target_area = np.array([]) + for _, dataset_model in dataset_models.items(): + dataset_axis = dataset_model.get_global_axis() + + source_area = np.concatenate( + [ + source_area, + _get_area( + penalty.source, + clp_labels, + clps, + penalty.source_intervals, + global_axis, + dataset_axis, + ), + ] + ) + + target_area = np.concatenate( + [ + target_area, + _get_area( + penalty.target, + clp_labels, + clps, + penalty.target_intervals, + global_axis, + dataset_axis, + ), + ] + ) area_penalty = np.abs(np.sum(source_area) - penalty.parameter * np.sum(target_area)) penalties.append(area_penalty * penalty.weight) @@ -228,14 +249,18 @@ def _get_area( clps: list[np.ndarray], intervals: list[tuple[float, float]], global_axis: np.ndarray, + dataset_axis: np.ndarray, ) -> np.ndarray: area = [] for interval in intervals: if interval[0] > global_axis[-1]: continue - - start_idx, end_idx = get_idx_from_interval(interval, global_axis) + bounded_interval = ( + max(interval[0], np.min(dataset_axis)), + min(interval[1], np.max(dataset_axis)), + ) + start_idx, end_idx = get_idx_from_interval(bounded_interval, global_axis) for i in range(start_idx, end_idx + 1): index_clp_labels = clp_labels[i] if isinstance(clp_labels[0], list) else clp_labels if clp_label in index_clp_labels: diff --git a/glotaran/model/clp_penalties.py b/glotaran/model/clp_penalties.py index 2918d9b41..3aa08ea32 100644 --- a/glotaran/model/clp_penalties.py +++ b/glotaran/model/clp_penalties.py @@ -73,6 +73,7 @@ def apply_spectral_penalties( group_tolerance: float, ) -> np.ndarray: + # TODO: seems to duplicate calculate_clp_penalties penalties = [] for penalty in model.clp_area_penalties: From 5238409bd2a37b6cd84c344e03299ca3902147e6 Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Sun, 5 Sep 2021 20:59:00 +0200 Subject: [PATCH 23/29] =?UTF-8?q?=F0=9F=A7=AA=F0=9F=9A=87=20Add=20integrat?= =?UTF-8?q?ion=20test=20result=20validation=20(#754)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🧪 Added result consistency test script * 🚇 Added result consistency testing to the integration test workflow * 👌 Propagate git interaction errors to inform users * 🔧 Add `pytest-allclose` for more usefull error reporting * 👌 Use colored pytest output for better readability * 👌 Show the commits used to create the results * 👌 Added tests for optimized 
parameters * 🧹 Renamed variables used for comparing to expected and current * 👌 Print up to 20 different values if "allclose" fails * Adjust tolerances to pass on CI if reference and current are the same * 👌 Make git interactions safer by passing the folder to use to git * 👌 Implemented 'EXAMPLE_BLOCKLIST' and added 'transient_absorption_two_dataset' * 👌 Only lower absolute tolerance for SVD comparison * 👌 Added 'ex_spectral_guidance' to EXAMPLE_BLOCKLIST * 📚 Added instructions to locally run result consistency test * 👌 Added dataset file name to report on failing test * 👌 Added data_var name to error reporting * 👌 Special cased missing weighted_data * 👌 Added label displaying on test_result_parameter_consistency fail * 🩹 Fixed line length issue * 👌 Use value difference to check data_vars * 👌 Swap abs_diff and float_resolution, that way rtol has some effect * 👌 Use float32 precision as absolute tolerance * 👌 Improved error reporting on failure by showing mean difference * ♻️ Refactored data_var tests not to run in a loop but using fixtures * 👌 Made epsilon for residual scale with original data * 👌 Add option to specify path for local comparison * 👌 Allow missing coords in some variables * 🩹 Reorder dimensions before comparison The same PR was merged to the v0.4.1 maintenance branch for comparison, see #760 Co-authored-by: Joris Snellenburg --- .github/test_result_consistency.py | 396 ++++++++++++++++++++++++ .github/workflows/integration-tests.yml | 52 +++- .gitignore | 5 + CONTRIBUTING.rst | 27 ++ requirements_dev.txt | 1 + tox.ini | 8 +- 6 files changed, 487 insertions(+), 2 deletions(-) create mode 100644 .github/test_result_consistency.py diff --git a/.github/test_result_consistency.py b/.github/test_result_consistency.py new file mode 100644 index 000000000..c7f2a209d --- /dev/null +++ b/.github/test_result_consistency.py @@ -0,0 +1,396 @@ +"""Tests to ensure result consistency.""" +from __future__ import annotations + +import os +import re +import subprocess +from collections import defaultdict +from functools import lru_cache +from pathlib import Path +from textwrap import dedent +from typing import TYPE_CHECKING +from typing import Protocol +from warnings import warn + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +if TYPE_CHECKING: + from xarray.core.coordinates import DataArrayCoordinates + + +REPO_ROOT = Path(__file__).parent.parent +RUN_EXAMPLES_MSG = ( + "run 'python scripts/run_examples.py run-all --headless' " + "in the 'pyglotaran-examples' repo root." +) + +# in general this list should be empty, but for non-stable examples this might be needed +EXAMPLE_BLOCKLIST = [ + "study_transient_absorption_two_dataset_analysis_result_2d_co_co2", + "ex_spectral_guidance", +] +ALLOW_MISSING_COORDS = {"spectral": ("matrix", "species_concentration")} + +SVD_PATTERN = re.compile(r"(?P<pre_fix>.+?)(right|left)_singular_vectors") + + +class AllCloseFixture(Protocol): + def __call__( + self, + a: float | np.ndarray | xr.DataArray, + b: float | np.ndarray | xr.DataArray, + rtol: float | np.ndarray = 1e-5, + atol: float | np.ndarray = 1e-8, + xtol: int = 0, + equal_nan: bool = False, + print_fail: int = 5, + record_rmse: bool = True, + ) -> bool: + ...
+ + +class GitError(Exception): + """Error raised when a git interaction didn't exit with a 0 returncode.""" + + +def get_compare_results_path() -> Path: + """Ensure that the comparison-results exist, are up to date and return their path.""" + compare_result_folder = REPO_ROOT / "comparison-results" + example_repo = "git@github.com:glotaran/pyglotaran-examples.git" + if not compare_result_folder.exists(): + proc_clone = subprocess.run( + [ + "git", + "clone", + "--depth", + "1", + "-b", + "comparison-results", + example_repo, + str(compare_result_folder), + ], + capture_output=True, + ) + if proc_clone.returncode != 0: + raise GitError(f"Error cloning {example_repo}:\n{proc_clone.stderr.decode()}") + if "GITHUB" not in os.environ: + proc_fetch = subprocess.run( + [ + "git", + "-C", + compare_result_folder.as_posix(), + "fetch", + "--depth", + "1", + "origin", + "comparison-results", + ], + capture_output=True, + ) + if proc_fetch.returncode != 0: + raise GitError(f"Error fetching {example_repo}:\n{proc_fetch.stderr.decode()}") + proc_reset = subprocess.run( + [ + "git", + "-C", + compare_result_folder.as_posix(), + "reset", + "--hard", + "origin/comparison-results", + ], + capture_output=True, + ) + if proc_reset.returncode != 0: + raise GitError(f"Error resetting {example_repo}:\n{proc_reset.stderr.decode()}") + return compare_result_folder + + +def get_current_result_path() -> Path: + """Get the path of the current results.""" + local_path = Path.home() / "pyglotaran_examples_results" + ci_path = REPO_ROOT / "comparison-results-current" + if local_path.exists(): + return local_path + elif ci_path.exists(): + return ci_path + else: + raise ValueError(f"No current results present, {RUN_EXAMPLES_MSG}") + + +def coord_test( + expected_coords: DataArrayCoordinates, + current_coords: DataArrayCoordinates, + file_name: str, + allclose: AllCloseFixture, + exact_match=False, + data_var_name: str = "unknown", +) -> None: + """Run tests that coordinates are exactly equal if string coords or close.""" + for expected_coord_name, expected_coord_value in expected_coords.items(): + if ( + expected_coord_name in ALLOW_MISSING_COORDS + and data_var_name in ALLOW_MISSING_COORDS[expected_coord_name] + ): + print(f"- allow missing coordinate: {expected_coord_name} in variable {data_var_name}") + continue + assert expected_coord_name in current_coords.keys(), ( + f"Missing coordinate: {expected_coord_name!r} in {file_name!r}, " + f"data_var {data_var_name!r}" + ) + + # assert ( + # expected_coord_value.dims == current_coords.dims + # ), f"Dimensions mismatch in {data_var_name!r}" + + if exact_match or expected_coord_value.data.dtype == object: + assert np.array_equal( + expected_coord_value, current_coords[expected_coord_name] + ), f"Coordinate value mismatch in {file_name!r}, data_var {data_var_name!r}" + else: + assert allclose( + expected_coord_value, current_coords[expected_coord_name], rtol=1e-5, print_fail=20 + ), f"Coordinate value mismatch in {file_name!r}, data_var {data_var_name!r}" + + +def data_var_test( + allclose: AllCloseFixture, + expected_result: xr.Dataset, + current_result: xr.Dataset, + file_name: str, + expected_var_name: str, +) -> None: + """Run test that a data_var of the current_result is close to the expected_result.""" + expected_var_value = expected_result.data_vars[expected_var_name] + + # weighted_data were always calculated and now will only be calculated + # when weights are applied + if expected_var_name == "weighted_data" and expected_var_name not in 
current_result.data_vars: + return + + assert ( + expected_var_name in current_result.data_vars + ), f"Missing data_var: {expected_var_name!r} in {file_name!r}" + current_data = current_result.data_vars[expected_var_name] + expected_values = expected_var_value + current_values = current_data + + eps = np.finfo(np.float32).eps + rtol = 1e-5 # default value of allclose + if expected_var_name.endswith("residual"): # type:ignore[operator] + eps = expected_result["data"].values.max() * 1e-8 + + if "singular_vectors" in expected_var_name: # type:ignore[operator] + # Sometimes the coords in the (right) singular vectors are swapped + if expected_values.dims != current_values.dims: + warn( + dedent( + f"""\n + Dimensions transposed for {expected_var_name!r} in {file_name!r}. + - expected: {expected_values.dims} + - current: {current_values.dims} + """ + ) + ) + expected_values = expected_values.transpose(*current_values.dims) + rtol = 1e-4 # instead of 1e-5 + eps = 1e-5 # instead of ~1.2e-7 + pre_fix = SVD_PATTERN.match(expected_var_name).group( # type:ignore[operator] + "pre_fix" + ) + expected_singular_values = expected_result.data_vars[f"{pre_fix}singular_values"] + + if expected_var_value.shape[0] == expected_singular_values.shape[0]: + expected_values_scaled = np.diag(expected_singular_values).dot(expected_var_value.data) + else: + expected_values_scaled = expected_var_value.data.dot(np.diag(expected_singular_values)) + + float_resolution = np.maximum( + np.abs(eps * expected_values_scaled), + np.ones(expected_var_value.data.shape) * eps, + ) + else: + float_resolution = np.maximum( + np.abs(eps * expected_var_value.data), + np.ones(expected_var_value.data.shape) * eps, + ) + abs_diff = np.abs(expected_values - current_values) + + assert allclose( + expected_values, + current_values, + atol=float_resolution, + rtol=rtol, + print_fail=20, + ), ( + f"Result data_var data mismatch: {expected_var_name!r} in {file_name!r}.\n" + "With sum of absolute difference: " + f"{float(np.sum(abs_diff))} and shape: {expected_var_value.shape}\n" + "Mean difference: " + f"{float(np.sum(abs_diff))/np.prod(expected_var_value.shape)}\n" + ) + + coord_test( + expected_var_value.coords, + current_data.coords, + file_name, + allclose, + data_var_name=expected_var_name, # type:ignore[operator] + ) + + +def map_result_files(file_glob_pattern: str) -> dict[str, list[tuple[Path, Path]]]: + """Load all datasets and map them in a dict.""" + result_map = defaultdict(list) + if os.getenv("COMPARE_RESULTS_LOCAL"): + compare_results_path = Path(os.getenv(key="COMPARE_RESULTS_LOCAL")) + warn( + dedent( + f""" + Using Path in environment variable COMPARE_RESULTS_LOCAL: + {compare_results_path.as_posix()} + """ + ) + ) + try: + if not compare_results_path.exists(): + raise FileNotFoundError( + dedent( + f""" + Path in COMPARE_RESULTS_LOCAL not valid: + {compare_results_path} <- does not exist + """ + ) + ) + except OSError as exception: + if str(compare_results_path).startswith(('"', "'")): + raise Exception( + "Path in COMPARE_RESULTS_LOCAL should not start with ' or \"" + ) from exception + raise exception + else: + compare_results_path = get_compare_results_path() + current_result_path = get_current_result_path() + for expected_result_file in compare_results_path.rglob(file_glob_pattern): + key = ( + expected_result_file.relative_to(compare_results_path) + .parent.as_posix() + .replace("/", "_") + ) + if key in EXAMPLE_BLOCKLIST: + continue + current_result_file = current_result_path / expected_result_file.relative_to( + 
compare_results_path + ) + if current_result_file.exists(): + result_map[key].append((expected_result_file, current_result_file)) + else: + warn( + UserWarning( + f"No current result for: {expected_result_file.as_posix()}, {RUN_EXAMPLES_MSG}" + ) + ) + return result_map + + +@lru_cache(maxsize=1) +def map_result_data() -> tuple[dict[str, list[tuple[xr.Dataset, xr.Dataset, str]]], set[str]]: + """Load all datasets and map them in a tuple of dict and set of data_var names.""" + result_map = defaultdict(list) + data_var_names = set() + result_file_map = map_result_files(file_glob_pattern="*.nc") + for key, path_list in result_file_map.items(): + for expected_result_file, current_result_file in path_list: + expected_result: xr.Dataset = xr.open_dataset(expected_result_file) + result_map[key].append( + ( + expected_result, + xr.open_dataset(current_result_file), + expected_result_file.name, + ) + ) + for data_var_name in expected_result.data_vars.keys(): + if data_var_name != "data": + data_var_names.add(data_var_name) + return result_map, data_var_names + + +@lru_cache(maxsize=1) +def map_result_parameters() -> dict[str, list[pd.DataFrame]]: + """Load all optimized parameter files and map them in a dict.""" + + result_map = defaultdict(list) + result_file_map = map_result_files(file_glob_pattern="*.csv") + for key, path_list in result_file_map.items(): + for expected_result_file, current_result_file in path_list: + compare_df = pd.DataFrame( + { + "expected": pd.read_csv(expected_result_file, index_col="label")["value"], + "current": pd.read_csv(current_result_file, index_col="label")["value"], + } + ) + result_map[key].append(compare_df) + return result_map + + +@pytest.mark.parametrize("result_name", map_result_data()[0].keys()) +def test_original_data_exact_consistency( + allclose: AllCloseFixture, + result_name: str, +): + """The original data need to be exactly the same.""" + for expected_result, current_result, file_name in map_result_data()[0][result_name]: + assert np.array_equal( + expected_result.data.data, current_result.data.data + ), f"Original data mismatch: {result_name!r} in {file_name!r}" + coord_test( + expected_result.data.coords, + current_result.data.coords, + file_name, + allclose, + exact_match=True, + data_var_name="data", + ) + + +@pytest.mark.parametrize("result_name", map_result_parameters().keys()) +def test_result_parameter_consistency( + allclose: AllCloseFixture, + result_name: str, +): + """Optimized parameters need to be approximately the same""" + for compare_df in map_result_parameters()[result_name]: + assert allclose( + compare_df["expected"].values, compare_df["current"].values, print_fail=20 + ), f"Parameter Mismatch: {compare_df.index}" + + +@pytest.mark.parametrize("result_name", map_result_data()[0].keys()) +def test_result_attr_consistency( + allclose: AllCloseFixture, + result_name: str, +): + """Resultdataset attributes need to be approximately the same.""" + for expected, current, file_name in map_result_data()[0][result_name]: + for expected_attr_name, expected_attr_value in expected.attrs.items(): + + assert ( + expected_attr_name in current.attrs.keys() + ), f"Missing result attribute: {expected_attr_name!r} in {file_name!r}" + + assert allclose( + expected_attr_value, current.attrs[expected_attr_name], print_fail=20 + ), f"Result attr value mismatch: {expected_attr_name!r} in {file_name!r}" + + +@pytest.mark.parametrize("expected_var_name", map_result_data()[1]) +@pytest.mark.parametrize("result_name", map_result_data()[0].keys()) +def 
test_result_data_var_consistency( + allclose: AllCloseFixture, result_name: str, expected_var_name: str +): + """Result dataset data variables need to be approximately the same.""" + for expected_result, current_result, file_name in map_result_data()[0][result_name]: + if expected_var_name in expected_result.data_vars.keys(): + data_var_test(allclose, expected_result, current_result, file_name, expected_var_name) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 8c02f9ef9..7c1e7c5e9 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -43,5 +43,55 @@ jobs: - name: Upload Example Plots Artifact uses: actions/upload-artifact@v2 with: - name: example-results + name: example-plots path: ${{ steps.example-run.outputs.plots-path }} + + - name: Upload Example Results + uses: actions/upload-artifact@v2 + with: + name: example-results + path: ~/pyglotaran_examples_results + + compare-results: + name: Compare Results + runs-on: ubuntu-latest + needs: [run-examples] + steps: + - name: Checkout glotaran + uses: actions/checkout@v2 + + - name: Checkout compare results + uses: actions/checkout@v2 + with: + repository: "glotaran/pyglotaran-examples" + ref: comparison-results + path: comparison-results + + - name: Download result artifact + uses: actions/download-artifact@v2 + with: + name: example-results + path: comparison-results-current + + - name: Show used versions for result creation + run: | + echo "::group:: ✔️ Compare-Results" + echo "✔️ pyglotaran-examples commit: $(< comparison-results/example_commit_sha.txt)" + echo "✔️ pyglotaran commit: $(< comparison-results/pyglotaran_commit_sha.txt)" + echo "::endgroup::" + echo "::group:: ♻️ Current-Results" + echo "♻️ pyglotaran-examples commit: $(< comparison-results-current/example_commit_sha.txt)" + echo "::endgroup::" + + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Install dependencies + run: | + pip install xarray pytest pytest-allclose netCDF4 + + - name: Compare Results + run: | + python -m pytest --color=yes .github/test_result_consistency.py diff --git a/.gitignore b/.gitignore index 6329ca758..9211a0b5b 100644 --- a/.gitignore +++ b/.gitignore @@ -171,3 +171,8 @@ _summary.ps # benchmark results benchmark/.asv .benchmarks/ + +# results to validate result consistency +# https://github.com/glotaran/pyglotaran-examples/tree/comparison-results +comparison-results +comparison-results-current diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 991396fe9..155e9f485 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -303,6 +303,7 @@ as an attribute to the parent package. to_be_removed_in_version="0.6.0", ) + Deprecating dict entries ~~~~~~~~~~~~~~~~~~~~~~~~ The possible dict deprecation actions are: @@ -314,6 +315,32 @@ The possible dict deprecation actions are: For full examples have a look at the examples from the docstring (:func:`deprecate_dict_entry`). + +Testing Result consistency -------------------------- +To test the consistency of results locally you need to clone the +`pyglotaran-examples <https://github.com/glotaran/pyglotaran-examples>`_ +and run them:: + + $ git clone https://github.com/glotaran/pyglotaran-examples + $ cd pyglotaran-examples + $ python scripts/run_examples.py run-all --headless + +.. note:: + Make sure you got the latest version (``git pull``) and are + on the correct branch for both ``pyglotaran`` and ``pyglotaran-examples``.
+ +The results from the examples will be saved in your home folder under ``pyglotaran_examples_results``. +Those results will then be compared to the 'gold standard' defined by the maintainers. + +To test the result consistency run:: + + $ pytest .github/test_result_consistency.py + +If needed this will clone the `'gold standard' results <https://github.com/glotaran/pyglotaran-examples/tree/comparison-results>`_ +to the folder ``comparison-results``, update them and test your current results against them. + + Deploying --------- diff --git a/requirements_dev.txt b/requirements_dev.txt index 1a2b4877c..e63b6b521 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -31,6 +31,7 @@ pytest-cov>=2.5.1 pytest-env>=0.6.2 pytest-runner>=2.11.1 pytest-benchmark>=3.1.1 +pytest-allclose>=1.0.0 # code quality assurance flake8>=3.8.3 diff --git a/tox.ini b/tox.ini index 4a4dfeee2..71ac77a40 100644 --- a/tox.ini +++ b/tox.ini @@ -7,12 +7,18 @@ envlist = py{38}, pre-commit, docs, docs-notebooks, docs-links [pytest] ; Uncomment the following lines to deactivate pyglotaran all plugins ; env = -; DEACTIVATE_GTA_PLUGINS=1 +; DEACTIVATE_GTA_PLUGINS=1 +; Uncomment "env =" and "COMPARE_RESULTS_LOCAL" and set it to a local folder +; with results to use as a reference in lieu of the comparison-results branch +; in the pyglotaran-examples git repository +; COMPARE_RESULTS_LOCAL=~/local_results/ ; On *nix +; COMPARE_RESULTS_LOCAL=%USERPROFILE%/local_results/ ; On Windows ; Uncomment to ignore deprecation warnings coming from pyglotaran ; (this helps to see the warnings from dependencies) ; filterwarnings = ; ignore:.+glotaran:GlotaranApiDeprecationWarning + [flake8] extend-ignore = E231, E203 max-line-length = 99 From 902454ead5770f8dba147f37411c291ac1dc94f4 Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Wed, 8 Sep 2021 00:49:40 +0200 Subject: [PATCH 24/29] =?UTF-8?q?=F0=9F=94=A7=20Use=20flake8-print=20in=20?= =?UTF-8?q?pre-commit=20(#772)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔧 Use flake8-print in pre-commit * 🧹 Rerun pre-commit after rebase --- .pre-commit-config.yaml | 3 +- glotaran/analysis/test/test_constraints.py | 4 +-- glotaran/analysis/test/test_optimization.py | 36 +++++++++---------- glotaran/analysis/test/test_penalties.py | 2 +- glotaran/analysis/test/test_problem.py | 8 ++--- glotaran/analysis/test/test_relations.py | 4 +-- .../io/ascii/wavelength_time_explicit_file.py | 7 ++-- .../builtin/io/yml/test/test_model_parser.py | 6 ++-- .../test/test_doas_model.py | 16 ++++----- .../spectral/test/test_spectral_model.py | 12 +++---- glotaran/model/test/test_model.py | 10 +++--- glotaran/test/test_spectral_decay.py | 14 ++++---- .../test/test_spectral_decay_full_model.py | 16 ++++----- glotaran/test/test_spectral_penalties.py | 18 +++++----- tox.ini | 2 ++ 15 files changed, 79 insertions(+), 79 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 997e06f46..d2ce60384 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -121,7 +121,8 @@ repos: - id: flake8 types: [file] types_or: [python, pyi] - additional_dependencies: [flake8-pyi, flake8-comprehensions] + additional_dependencies: + [flake8-pyi, flake8-comprehensions, flake8-print] - repo: https://github.com/myint/rstcheck rev: "3f92957478422df87bd730abde66f089cc1ee19b" diff --git a/glotaran/analysis/test/test_constraints.py b/glotaran/analysis/test/test_constraints.py index 07b0c9611..82f80fd6a 100644 --- a/glotaran/analysis/test/test_constraints.py +++ b/glotaran/analysis/test/test_constraints.py @@
-17,7 +17,7 @@ def test_constraint(index_dependent, grouped): model.megacomplex["m1"].is_index_dependent = index_dependent model.constraints.append(ZeroConstraint.from_dict({"target": "s2"})) - print("grouped", grouped, "index_dependent", index_dependent) # noqa T001 + print("grouped", grouped, "index_dependent", index_dependent) dataset = simulate( suite.sim_model, "dataset1", @@ -36,7 +36,7 @@ def test_constraint(index_dependent, grouped): matrix = problem.matrices["dataset1"][0] if index_dependent else problem.matrices["dataset1"] result_data = problem.create_result_data() - print(result_data) # noqa T001 + print(result_data) clps = result_data["dataset1"].clp assert "s2" not in reduced_matrix.clp_labels diff --git a/glotaran/analysis/test/test_optimization.py b/glotaran/analysis/test/test_optimization.py index 80d4268c9..2576b02e3 100644 --- a/glotaran/analysis/test/test_optimization.py +++ b/glotaran/analysis/test/test_optimization.py @@ -32,26 +32,26 @@ def test_optimization(suite, index_dependent, grouped, weight, method): model.megacomplex["m1"].is_index_dependent = index_dependent - print("Grouped:", grouped) # noqa T001 - print("Index dependent:", index_dependent) # noqa T001 + print("Grouped:", grouped) + print("Index dependent:", index_dependent) sim_model = suite.sim_model sim_model.megacomplex["m1"].is_index_dependent = index_dependent - print(model.validate()) # noqa T001 + print(model.validate()) assert model.valid() - print(sim_model.validate()) # noqa T001 + print(sim_model.validate()) assert sim_model.valid() wanted_parameters = suite.wanted_parameters - print(wanted_parameters) # noqa T001 - print(sim_model.validate(wanted_parameters)) # noqa T001 + print(wanted_parameters) + print(sim_model.validate(wanted_parameters)) assert sim_model.valid(wanted_parameters) initial_parameters = suite.initial_parameters - print(initial_parameters) # noqa T001 - print(model.validate(initial_parameters)) # noqa T001 + print(initial_parameters) + print(model.validate(initial_parameters)) assert model.valid(initial_parameters) assert ( model.dataset["dataset1"].fill(model, initial_parameters).index_dependent() @@ -70,9 +70,9 @@ def test_optimization(suite, index_dependent, grouped, weight, method): wanted_parameters, {"global": global_axis, "model": model_axis}, ) - print(f"Dataset {i+1}") # noqa T001 - print("=============") # noqa T001 - print(dataset) # noqa T001 + print(f"Dataset {i+1}") + print("=============") + print(dataset) if hasattr(suite, "scale"): dataset["data"] /= suite.scale @@ -97,7 +97,7 @@ def test_optimization(suite, index_dependent, grouped, weight, method): ) result = optimize(scheme, raise_exception=True) - print(result.optimized_parameters) # noqa T001 + print(result.optimized_parameters) assert result.success optimized_scheme = result.get_scheme() assert result.optimized_parameters == optimized_scheme.parameters @@ -111,9 +111,9 @@ def test_optimization(suite, index_dependent, grouped, weight, method): for i, dataset in enumerate(data.values()): resultdata = result.data[f"dataset{i+1}"] - print(f"Result Data {i+1}") # noqa T001 - print("=================") # noqa T001 - print(resultdata) # noqa T001 + print(f"Result Data {i+1}") + print("=================") + print(resultdata) assert "residual" in resultdata assert "residual_left_singular_vectors" in resultdata assert "residual_right_singular_vectors" in resultdata @@ -121,7 +121,7 @@ def test_optimization(suite, index_dependent, grouped, weight, method): assert np.array_equal(dataset.coords["model"], 
resultdata.coords["model"]) assert np.array_equal(dataset.coords["global"], resultdata.coords["global"]) assert dataset.data.shape == resultdata.data.shape - print(dataset.data[0, 0], resultdata.data[0, 0]) # noqa T001 + print(dataset.data[0, 0], resultdata.data[0, 0]) assert np.allclose(dataset.data, resultdata.data) if weight: assert "weight" in resultdata @@ -136,7 +136,7 @@ def test_optimization_full_model(index_dependent): model = FullModel.model model.megacomplex["m1"].is_index_dependent = index_dependent - print(model.validate()) # noqa T001 + print(model.validate()) assert model.valid() parameters = FullModel.parameters @@ -164,6 +164,6 @@ def test_optimization_full_model(index_dependent): assert np.allclose(param.value, parameters.get(label).value, rtol=1e-1) clp = result_data.clp - print(clp) # noqa T001 + print(clp) assert clp.shape == (4, 4) assert all(np.isclose(1.0, c) for c in np.diagonal(clp)) diff --git a/glotaran/analysis/test/test_penalties.py b/glotaran/analysis/test/test_penalties.py index 3c3c3c1b2..43f5c4874 100644 --- a/glotaran/analysis/test/test_penalties.py +++ b/glotaran/analysis/test/test_penalties.py @@ -33,7 +33,7 @@ def test_penalties(index_dependent, grouped): global_axis = np.arange(50) - print("grouped", grouped, "index_dependent", index_dependent) # T001 + print("grouped", grouped, "index_dependent", index_dependent) dataset = simulate( suite.sim_model, "dataset1", diff --git a/glotaran/analysis/test/test_problem.py b/glotaran/analysis/test/test_problem.py index 241618b2a..e3cac4504 100644 --- a/glotaran/analysis/test/test_problem.py +++ b/glotaran/analysis/test/test_problem.py @@ -147,7 +147,7 @@ def test_prepare_data(): ], } model = SimpleTestModel.from_dict(model_dict) - print(model.validate()) # T001 # noqa T001 + print(model.validate()) assert model.valid() parameters = ParameterGroup.from_list([]) @@ -165,7 +165,7 @@ def test_prepare_data(): problem = Problem(scheme) data = problem.data["dataset1"] - print(data) # noqa T001 + print(data) assert "data" in data assert "weight" in data @@ -181,7 +181,7 @@ def test_prepare_data(): } ) model = SimpleTestModel.from_dict(model_dict) - print(model.validate()) # T001 # noqa T001 + print(model.validate()) assert model.valid() scheme = Scheme(model, parameters, {"dataset1": dataset}) @@ -214,5 +214,5 @@ def test_full_model_problem(): clp = result.clp assert clp.shape == (4, 4) - print(np.diagonal(clp)) # noqa T001 + print(np.diagonal(clp)) assert all(np.isclose(1.0, c) for c in np.diagonal(clp)) diff --git a/glotaran/analysis/test/test_relations.py b/glotaran/analysis/test/test_relations.py index ccc5f6d05..480db3124 100644 --- a/glotaran/analysis/test/test_relations.py +++ b/glotaran/analysis/test/test_relations.py @@ -19,7 +19,7 @@ def test_relations(index_dependent, grouped): model.relations.append(Relation.from_dict({"source": "s1", "target": "s2", "parameter": "3"})) parameters = ParameterGroup.from_list([11e-4, 22e-5, 2]) - print("grouped", grouped, "index_dependent", index_dependent) # T001 + print("grouped", grouped, "index_dependent", index_dependent) dataset = simulate( suite.sim_model, "dataset1", @@ -38,7 +38,7 @@ def test_relations(index_dependent, grouped): matrix = problem.matrices["dataset1"][0] if index_dependent else problem.matrices["dataset1"] result_data = problem.create_result_data() - print(result_data) # T001 + print(result_data) clps = result_data["dataset1"].clp assert "s2" not in reduced_matrix.clp_labels diff --git a/glotaran/builtin/io/ascii/wavelength_time_explicit_file.py 
b/glotaran/builtin/io/ascii/wavelength_time_explicit_file.py index 4b11b63b2..9bceecde5 100644 --- a/glotaran/builtin/io/ascii/wavelength_time_explicit_file.py +++ b/glotaran/builtin/io/ascii/wavelength_time_explicit_file.py @@ -75,8 +75,7 @@ def write( # TODO: write a more elegant method if os.path.isfile(self._file) and not overwrite: - print(f"File {os.path.isfile(self._file)} already exists") - raise Exception("File already exist.") + raise FileExistsError(f"File already exist:\n{self._file}") comment = self._comment + " " + comment comments = "# Filename: " + str(self._file) + "\n" + " ".join(comment.splitlines()) + "\n" @@ -215,10 +214,8 @@ def get_format_name(self): def get_interval_number(line): - interval_number = None match = re.search(r"intervalnr\s(.*)", line.strip().lower()) - if match: - interval_number = match.group(1) + interval_number = match.group(1) if match else None if not interval_number: interval_number = re.search(r"\d+", line[::-1]).group()[::-1] try: diff --git a/glotaran/builtin/io/yml/test/test_model_parser.py b/glotaran/builtin/io/yml/test/test_model_parser.py index be4ac0ce3..cc98520ef 100644 --- a/glotaran/builtin/io/yml/test/test_model_parser.py +++ b/glotaran/builtin/io/yml/test/test_model_parser.py @@ -25,7 +25,7 @@ def model(): spec_path = join(THIS_DIR, "test_model_spec.yml") m = load_model(spec_path) - print(m.markdown()) # noqa + print(m.markdown()) return m @@ -61,7 +61,7 @@ def test_dataset(model): def test_constraints(model): - print(model.constraints) # noqa + print(model.constraints) assert len(model.constraints) == 2 zero = model.constraints[0] @@ -88,7 +88,7 @@ def test_penalties(model): def test_relations(model): - print(model.relations) # noqa + print(model.relations) assert len(model.relations) == 1 rel = model.relations[0] diff --git a/glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py b/glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py index fe1c29927..91a07519c 100755 --- a/glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py +++ b/glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py @@ -364,25 +364,25 @@ class OneOscillationWithSequentialModel: ) def test_doas_model(suite): - print(suite.sim_model.validate()) # noqa + print(suite.sim_model.validate()) assert suite.sim_model.valid() - print(suite.model.validate()) # noqa + print(suite.model.validate()) assert suite.model.valid() - print(suite.sim_model.validate(suite.wanted_parameter)) # noqa + print(suite.sim_model.validate(suite.wanted_parameter)) assert suite.sim_model.valid(suite.wanted_parameter) - print(suite.model.validate(suite.parameter)) # noqa + print(suite.model.validate(suite.parameter)) assert suite.model.valid(suite.parameter) dataset = simulate(suite.sim_model, "dataset1", suite.wanted_parameter, suite.axis) - print(dataset) # noqa + print(dataset) assert dataset.data.shape == (suite.axis["time"].size, suite.axis["spectral"].size) - print(suite.parameter) # noqa - print(suite.wanted_parameter) # noqa + print(suite.parameter) + print(suite.wanted_parameter) data = {"dataset1": dataset} scheme = Scheme( @@ -392,7 +392,7 @@ def test_doas_model(suite): maximum_number_function_evaluations=20, ) result = optimize(scheme, raise_exception=True) - print(result.optimized_parameters) # noqa + print(result.optimized_parameters) for label, param in result.optimized_parameters.all(): assert np.allclose(param.value, suite.wanted_parameter.get(label).value, rtol=1e-1) diff --git 
a/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py index 4b5659943..0764ad87a 100644 --- a/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py +++ b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py @@ -280,19 +280,19 @@ class ThreeCompartmentModel: def test_spectral_model(suite): model = suite.spectral_model - print(model.validate()) # noqa + print(model.validate()) assert model.valid() wanted_parameters = suite.spectral_parameters - print(model.validate(wanted_parameters)) # noqa - print(wanted_parameters) # noqa + print(model.validate(wanted_parameters)) + print(wanted_parameters) assert model.valid(wanted_parameters) initial_parameters = suite.spectral_parameters - print(model.validate(initial_parameters)) # noqa + print(model.validate(initial_parameters)) assert model.valid(initial_parameters) - print(model.markdown(initial_parameters)) # noqa + print(model.markdown(initial_parameters)) dataset = simulate(model, "dataset1", wanted_parameters, suite.axis, suite.clp) @@ -307,7 +307,7 @@ def test_spectral_model(suite): maximum_number_function_evaluations=20, ) result = optimize(scheme) - print(result.optimized_parameters) # noqa + print(result.optimized_parameters) for label, param in result.optimized_parameters.all(): assert np.allclose(param.value, wanted_parameters.get(label).value, rtol=1e-1) diff --git a/glotaran/model/test/test_model.py b/glotaran/model/test/test_model.py index c28c34c50..b8a446b1d 100644 --- a/glotaran/model/test/test_model.py +++ b/glotaran/model/test/test_model.py @@ -262,13 +262,13 @@ def test_model_misc(test_model: Model): def test_model_validity(test_model: Model, model_error: Model, parameter: ParameterGroup): - print(test_model.test_item1["t1"]) # noqa T001 - print(test_model.problem_list()) # noqa T001 - print(test_model.problem_list(parameter)) # noqa T001 + print(test_model.test_item1["t1"]) + print(test_model.problem_list()) + print(test_model.problem_list(parameter)) assert test_model.valid() assert test_model.valid(parameter) - print(model_error.problem_list()) # noqa T001 - print(model_error.problem_list(parameter)) # noqa T001 + print(model_error.problem_list()) + print(model_error.problem_list(parameter)) assert not model_error.valid() assert len(model_error.problem_list()) == 5 assert not model_error.valid(parameter) diff --git a/glotaran/test/test_spectral_decay.py b/glotaran/test/test_spectral_decay.py index a23ce515d..bede508b8 100644 --- a/glotaran/test/test_spectral_decay.py +++ b/glotaran/test/test_spectral_decay.py @@ -243,19 +243,19 @@ class ThreeComponentSequential: def test_kinetic_model(suite, nnls): model = suite.model - print(model.validate()) # noqa T001 + print(model.validate()) assert model.valid() wanted_parameters = suite.wanted_parameters - print(model.validate(wanted_parameters)) # noqa T001 - print(wanted_parameters) # noqa T001 + print(model.validate(wanted_parameters)) + print(wanted_parameters) assert model.valid(wanted_parameters) initial_parameters = suite.initial_parameters - print(model.validate(initial_parameters)) # noqa T001 + print(model.validate(initial_parameters)) assert model.valid(initial_parameters) - print(model.markdown(wanted_parameters)) # noqa T001 + print(model.markdown(wanted_parameters)) dataset = simulate(model, "dataset1", wanted_parameters, suite.axis) @@ -272,14 +272,14 @@ def test_kinetic_model(suite, nnls): group=False, ) result = optimize(scheme) - 
print(result.optimized_parameters) # noqa T001 + print(result.optimized_parameters) for label, param in result.optimized_parameters.all(): assert np.allclose(param.value, wanted_parameters.get(label).value) resultdata = result.data["dataset1"] - print(resultdata) # noqa T001 + print(resultdata) assert np.array_equal(dataset["time"], resultdata["time"]) assert np.array_equal(dataset["spectral"], resultdata["spectral"]) diff --git a/glotaran/test/test_spectral_decay_full_model.py b/glotaran/test/test_spectral_decay_full_model.py index 28d441db9..b21df0e52 100644 --- a/glotaran/test/test_spectral_decay_full_model.py +++ b/glotaran/test/test_spectral_decay_full_model.py @@ -178,19 +178,19 @@ class ThreeComponentSequential: def test_kinetic_model(suite, nnls): model = suite.model - print(model.validate()) # noqa T001 + print(model.validate()) assert model.valid() wanted_parameters = suite.wanted_parameters - print(model.validate(wanted_parameters)) # noqa T001 - print(wanted_parameters) # noqa T001 + print(model.validate(wanted_parameters)) + print(wanted_parameters) assert model.valid(wanted_parameters) initial_parameters = suite.initial_parameters - print(model.validate(initial_parameters)) # noqa T001 + print(model.validate(initial_parameters)) assert model.valid(initial_parameters) - print(model.markdown(wanted_parameters)) # noqa T001 + print(model.markdown(wanted_parameters)) dataset = simulate(model, "dataset1", wanted_parameters, suite.axis) @@ -207,15 +207,15 @@ def test_kinetic_model(suite, nnls): group=False, ) result = optimize(scheme) - print(result.optimized_parameters) # noqa T001 + print(result.optimized_parameters) for label, param in result.optimized_parameters.all(): - print(label, param.value, wanted_parameters.get(label).value) # noqa T001 + print(label, param.value, wanted_parameters.get(label).value) assert np.allclose(param.value, wanted_parameters.get(label).value, rtol=1e-1) resultdata = result.data["dataset1"] - print(resultdata) # noqa T001 + print(resultdata) assert np.array_equal(dataset["time"], resultdata["time"]) assert np.array_equal(dataset["spectral"], resultdata["spectral"]) diff --git a/glotaran/test/test_spectral_penalties.py b/glotaran/test/test_spectral_penalties.py index a9a870555..2d07f5650 100644 --- a/glotaran/test/test_spectral_penalties.py +++ b/glotaran/test/test_spectral_penalties.py @@ -204,7 +204,7 @@ def test_equal_area_penalties(debug=False): model_sim = SpectralDecayModel.from_dict(mspec_sim) model_wp = SpectralDecayModel.from_dict(mspec_fit_wp) model_np = SpectralDecayModel.from_dict(mspec_fit_np) - print(model_np) # noqa T001 + print(model_np) # %% Parameter specification (pspec) @@ -224,9 +224,9 @@ def test_equal_area_penalties(debug=False): param_np = ParameterGroup.from_dict(pspec_np) # %% Print models with parameters - print(model_sim.markdown(param_sim)) # noqa T001 - print(model_wp.markdown(param_wp)) # noqa T001 - print(model_np.markdown(param_np)) # noqa T001 + print(model_sim.markdown(param_sim)) + print(model_wp.markdown(param_wp)) + print(model_np.markdown(param_np)) # %% simulated_data = simulate( @@ -254,7 +254,7 @@ def test_equal_area_penalties(debug=False): maximum_number_function_evaluations=optim_spec.max_nfev, ) result_np = optimize(scheme_np) - print(result_np) # noqa T001 + print(result_np) # %% Optimizing model with penalty fixed inputs (wp_ifix) scheme_wp = Scheme( @@ -265,7 +265,7 @@ def test_equal_area_penalties(debug=False): maximum_number_function_evaluations=optim_spec.max_nfev, ) result_wp = optimize(scheme_wp) 
- print(result_wp) # noqa T001 + print(result_wp) if debug: # %% Plot results @@ -278,10 +278,10 @@ def test_equal_area_penalties(debug=False): plt.show() # %% Test calculation - print(result_wp.data["dataset1"]) # noqa T001 + print(result_wp.data["dataset1"]) area1_np = np.sum(result_np.data["dataset1"].species_associated_spectra.sel(species="s1")) area2_np = np.sum(result_np.data["dataset1"].species_associated_spectra.sel(species="s2")) - print("area_np", area1_np, area2_np) # noqa T001 + print("area_np", area1_np, area2_np) assert not np.isclose(area1_np, area2_np) area1_wp = np.sum(result_wp.data["dataset1"].species_associated_spectra.sel(species="s1")) @@ -291,7 +291,7 @@ def test_equal_area_penalties(debug=False): input_ratio = result_wp.optimized_parameters.get("i.1") / result_wp.optimized_parameters.get( "i.2" ) - print("input", input_ratio) # noqa T001 + print("input", input_ratio) assert np.isclose(input_ratio, 1.5038858115) diff --git a/tox.ini b/tox.ini index 71ac77a40..9972404d9 100644 --- a/tox.ini +++ b/tox.ini @@ -29,6 +29,8 @@ per-file-ignores = docs/source/conf.py: E501 # Typedef files are formatted differently *.pyi: E301, E302, F401 + # Allow printing in test file + test_*.py: T001 [testenv:docs] direct = true From 9d49c125bedba140683bd6bd1773cd41fa91a486 Mon Sep 17 00:00:00 2001 From: Joris Snellenburg Date: Thu, 9 Sep 2021 02:26:42 +0200 Subject: [PATCH 25/29] Fix codacy issues on staging (#806) * Fix Parameters differ from overridden 'index_dependent' method Adapt to change in base class Megacomplex * Fix Parameters differ from overridden 'applies' method Change name of argument from index to value and type to float * Fix Dangerous default value {} as argument Avoid the use of mutable types as default argument --- .../megacomplexes/baseline/baseline_megacomplex.py | 2 +- .../coherent_artifact_megacomplex.py | 2 +- .../megacomplexes/spectral/spectral_megacomplex.py | 2 +- glotaran/model/constraint.py | 13 ++++--------- glotaran/model/interval_property.py | 2 +- glotaran/model/item.py | 5 ++++- 6 files changed, 12 insertions(+), 14 deletions(-) diff --git a/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py b/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py index 7e1a35bb4..effeeedaf 100644 --- a/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py +++ b/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py @@ -21,7 +21,7 @@ def calculate_matrix( matrix = np.ones((model_axis.size, 1), dtype=np.float64) return clp_label, matrix - def index_dependent(self, dataset: DatasetModel) -> bool: + def index_dependent(self, dataset_model: DatasetModel) -> bool: return False def finalize_data( diff --git a/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py b/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py index 4f113a929..a1aa8a2d8 100644 --- a/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py +++ b/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py @@ -59,7 +59,7 @@ def calculate_matrix( def compartments(self): return [f"coherent_artifact_{i}" for i in range(1, self.order + 1)] - def index_dependent(self, dataset: DatasetModel) -> bool: + def index_dependent(self, dataset_model: DatasetModel) -> bool: return False def finalize_data( diff --git a/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py b/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py index ab2f70900..51242fca7 100644 
--- a/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py +++ b/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py @@ -52,7 +52,7 @@ def calculate_matrix( return compartments, matrix - def index_dependent(self, dataset: DatasetModel) -> bool: + def index_dependent(self, dataset_model: DatasetModel) -> bool: return False def finalize_data( diff --git a/glotaran/model/constraint.py b/glotaran/model/constraint.py index 948f581ca..0c3fc80d2 100644 --- a/glotaran/model/constraint.py +++ b/glotaran/model/constraint.py @@ -1,15 +1,10 @@ """This package contains compartment constraint items.""" from __future__ import annotations -from typing import TYPE_CHECKING - from glotaran.model.interval_property import IntervalProperty from glotaran.model.item import model_item from glotaran.model.item import model_item_typed -if TYPE_CHECKING: - from typing import Any - @model_item( properties={ @@ -22,20 +17,20 @@ class OnlyConstraint(IntervalProperty): """A only constraint sets the calculated matrix row of a compartment to 0 outside the given intervals.""" - def applies(self, index: Any) -> bool: + def applies(self, value: float) -> bool: """ - Returns true if the indexx is in one of the intervals. + Returns true if ``value`` is in one of the intervals. Parameters ---------- - index : + index : float Returns ------- applies : bool """ - return not super().applies(index) + return not super().applies(value) @model_item( diff --git a/glotaran/model/interval_property.py b/glotaran/model/interval_property.py index 0b3d43bd2..7549e9529 100644 --- a/glotaran/model/interval_property.py +++ b/glotaran/model/interval_property.py @@ -21,7 +21,7 @@ class IntervalProperty: def applies(self, value: float) -> bool: """ - Returns true if the index is in one of the intervals. + Returns true if ``value`` is in one of the intervals. Parameters ---------- diff --git a/glotaran/model/item.py b/glotaran/model/item.py index 21dd6188f..f3133090f 100644 --- a/glotaran/model/item.py +++ b/glotaran/model/item.py @@ -29,7 +29,7 @@ def model_item( - properties: Any | dict[str, dict[str, Any]] = {}, + properties: None | dict[str, dict[str, Any]] = None, has_type: bool = False, has_label: bool = True, ) -> Callable: @@ -59,6 +59,9 @@ def model_item( If false no label property will be added. """ + if properties is None: + properties = {} + def decorator(cls): setattr(cls, "_glotaran_has_label", has_label) From 9892b6efeac0cfc97b624a73b01f8321cfb36604 Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Sat, 11 Sep 2021 13:04:28 +0200 Subject: [PATCH 26/29] =?UTF-8?q?=F0=9F=A7=AA=20Added=20more=20tools=20fro?= =?UTF-8?q?m=20the=20'pygrep-hooks'=20pre-commit=20hook=20(#805)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .pre-commit-config.yaml | 7 +++++++ glotaran/deprecation/deprecation_utils.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d2ce60384..7f1d4d407 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -134,6 +134,13 @@ repos: rev: v1.9.0 hooks: - id: rst-backticks + - id: python-check-blanket-noqa + exclude: "docs|tests?" + - id: python-check-blanket-type-ignore + exclude: "docs|tests?" 
+ - id: python-use-type-annotations + - id: rst-directive-colons + - id: rst-inline-touching-normal - repo: https://github.com/codespell-project/codespell rev: v2.1.0 diff --git a/glotaran/deprecation/deprecation_utils.py b/glotaran/deprecation/deprecation_utils.py index d6547116b..69db4d88c 100644 --- a/glotaran/deprecation/deprecation_utils.py +++ b/glotaran/deprecation/deprecation_utils.py @@ -98,7 +98,7 @@ def parse_version(version_str: str) -> tuple[int, int, int]: if len(split_version) < 3: raise ValueError(error_message) try: - return tuple(map(int, split_version[:3])) # type:ignore [return-value] + return tuple(map(int, split_version[:3])) # type:ignore[return-value] except ValueError: raise ValueError(error_message) @@ -346,7 +346,7 @@ def outer_wrapper(deprecated_object: DecoratedCallable) -> DecoratedCallable: setattr( deprecated_object, "__new__", - inject_warn_into_call(deprecated_object.__new__), # type: ignore [arg-type] + inject_warn_into_call(deprecated_object.__new__), # type: ignore[arg-type] ) return deprecated_object # type: ignore[return-value] From ae63eb45e4b895b2463989e95859a0b6316fb40e Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Sat, 11 Sep 2021 23:41:43 +0200 Subject: [PATCH 27/29] =?UTF-8?q?=F0=9F=A9=B9=20Fix=20coherent=20artifact?= =?UTF-8?q?=20crash=20for=20index=20dependent=20models=20(#808)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🩹 Fixed dimension mismatching of CoherentArtifactMegacomplex If the dataset_model was index_dependent, CoherentArtifactMegacomplex.finalize_data did crash because it tried cast data of dimensions (global_dimension, model_dimension, "coherent_artifact_order") to the dimensions (model_dimension, "coherent_artifact_order") * ♻️ Refactored test_coherent_artifact to be more readable * 🧪 Parametrized test_coherent_artifact to cover index dependent case * ♻️ Renamed 'DatasetModel.index_dependent' ->'is_index_dependent' (see #755) * 🔧 Let mypy run normally in pre-commit * 👌 Renamed 'coherent_artifact_concentration' to 'coherent_artifact_response' Ref: https://github.com/glotaran/pyglotaran/pull/808#discussion_r706669563 --- .pre-commit-config.yaml | 2 - glotaran/analysis/problem.py | 2 +- glotaran/analysis/problem_grouped.py | 2 +- glotaran/analysis/problem_ungrouped.py | 10 +-- glotaran/analysis/simulation.py | 4 +- glotaran/analysis/test/test_optimization.py | 14 ++-- .../coherent_artifact_megacomplex.py | 7 +- .../test/test_coherent_artifact.py | 65 ++++++++++++------- glotaran/model/dataset_model.py | 2 +- glotaran/model/dataset_model.pyi | 2 +- 10 files changed, 65 insertions(+), 45 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7f1d4d407..b3ac808be 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -96,8 +96,6 @@ repos: - id: mypy files: "^glotaran/(plugin_system|utils|deprecation)" exclude: "docs" - args: [glotaran] - pass_filenames: false additional_dependencies: [types-all] - repo: https://github.com/econchick/interrogate diff --git a/glotaran/analysis/problem.py b/glotaran/analysis/problem.py index bc7463499..4d10d4c31 100644 --- a/glotaran/analysis/problem.py +++ b/glotaran/analysis/problem.py @@ -340,7 +340,7 @@ def create_result_dataset(self, label: str, copy: bool = True) -> xr.Dataset: model_dimension = dataset_model.get_model_dimension() if copy: dataset = dataset.copy() - if dataset_model.index_dependent(): + if dataset_model.is_index_dependent(): dataset = 
self.create_index_dependent_result_dataset(label, dataset) else: dataset = self.create_index_independent_result_dataset(label, dataset) diff --git a/glotaran/analysis/problem_grouped.py b/glotaran/analysis/problem_grouped.py index e898970b2..e26f33bea 100644 --- a/glotaran/analysis/problem_grouped.py +++ b/glotaran/analysis/problem_grouped.py @@ -48,7 +48,7 @@ def __init__(self, scheme: Scheme): raise ValueError( f"Cannot group datasets. Model dimension '{model_dimensions}' do not match." ) - self._index_dependent = any(d.index_dependent() for d in self.dataset_models.values()) + self._index_dependent = any(d.is_index_dependent() for d in self.dataset_models.values()) self._global_dimension = global_dimensions.pop() self._model_dimension = model_dimensions.pop() self._group_clp_labels = None diff --git a/glotaran/analysis/problem_ungrouped.py b/glotaran/analysis/problem_ungrouped.py index 8b7fedb18..8afd5f164 100644 --- a/glotaran/analysis/problem_ungrouped.py +++ b/glotaran/analysis/problem_ungrouped.py @@ -59,7 +59,7 @@ def calculate_matrices( for label, dataset_model in self.dataset_models.items(): - if dataset_model.index_dependent(): + if dataset_model.is_index_dependent(): self._calculate_index_dependent_matrix(label, dataset_model) else: self._calculate_index_independent_matrix(label, dataset_model) @@ -132,10 +132,10 @@ def _calculate_residual(self, label: str, dataset_model: DatasetModel): for i, index in enumerate(global_axis): reduced_clp_labels, reduced_matrix = ( self.reduced_matrices[label][i] - if dataset_model.index_dependent() + if dataset_model.is_index_dependent() else self.reduced_matrices[label] ) - if not dataset_model.index_dependent(): + if not dataset_model.is_index_dependent(): reduced_matrix = reduced_matrix.copy() if dataset_model.scale is not None: @@ -183,7 +183,7 @@ def _calculate_full_model_residual(self, label: str, dataset_model: DatasetModel model_matrix = self.matrices[label] global_matrix = self.global_matrices[label].matrix - if dataset_model.index_dependent(): + if dataset_model.is_index_dependent(): matrix = np.concatenate( [ np.kron(global_matrix[i, :], model_matrix[i].matrix) @@ -205,7 +205,7 @@ def _calculate_full_model_residual(self, label: str, dataset_model: DatasetModel def _get_clp_labels(self, label: str, index: int = 0): return ( self.matrices[label][index].clp_labels - if self.dataset_models[label].index_dependent() + if self.dataset_models[label].is_index_dependent() else self.matrices[label].clp_labels ) diff --git a/glotaran/analysis/simulation.py b/glotaran/analysis/simulation.py index 2bbdc0b00..22ebe0a38 100644 --- a/glotaran/analysis/simulation.py +++ b/glotaran/analysis/simulation.py @@ -93,7 +93,7 @@ def simulate_clp( ) for index, _ in enumerate(global_axis) ] - if dataset_model.index_dependent() + if dataset_model.is_index_dependent() else calculate_matrix(dataset_model, {}) ) @@ -108,7 +108,7 @@ def simulate_clp( ) result = result.to_dataset(name="data") for i in range(global_axis.size): - index_matrix = matrices[i] if dataset_model.index_dependent() else matrices + index_matrix = matrices[i] if dataset_model.is_index_dependent() else matrices result.data[:, i] = np.dot( index_matrix.matrix, clp.isel({global_dimension: i}).sel({"clp_label": index_matrix.clp_labels}), diff --git a/glotaran/analysis/test/test_optimization.py b/glotaran/analysis/test/test_optimization.py index 2576b02e3..2a8e97ea7 100644 --- a/glotaran/analysis/test/test_optimization.py +++ b/glotaran/analysis/test/test_optimization.py @@ -12,7 +12,7 @@ from 
glotaran.project import Scheme -@pytest.mark.parametrize("index_dependent", [True, False]) +@pytest.mark.parametrize("is_index_dependent", [True, False]) @pytest.mark.parametrize("grouped", [True, False]) @pytest.mark.parametrize("weight", [True, False]) @pytest.mark.parametrize( @@ -27,16 +27,16 @@ "suite", [OneCompartmentDecay, TwoCompartmentDecay, ThreeDatasetDecay, MultichannelMulticomponentDecay], ) -def test_optimization(suite, index_dependent, grouped, weight, method): +def test_optimization(suite, is_index_dependent, grouped, weight, method): model = suite.model - model.megacomplex["m1"].is_index_dependent = index_dependent + model.megacomplex["m1"].is_index_dependent = is_index_dependent print("Grouped:", grouped) - print("Index dependent:", index_dependent) + print("Index dependent:", is_index_dependent) sim_model = suite.sim_model - sim_model.megacomplex["m1"].is_index_dependent = index_dependent + sim_model.megacomplex["m1"].is_index_dependent = is_index_dependent print(model.validate()) assert model.valid() @@ -54,8 +54,8 @@ def test_optimization(suite, index_dependent, grouped, weight, method): print(model.validate(initial_parameters)) assert model.valid(initial_parameters) assert ( - model.dataset["dataset1"].fill(model, initial_parameters).index_dependent() - == index_dependent + model.dataset["dataset1"].fill(model, initial_parameters).is_index_dependent() + == is_index_dependent ) nr_datasets = 3 if issubclass(suite, ThreeDatasetDecay) else 1 diff --git a/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py b/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py index a1aa8a2d8..2c81cf93e 100644 --- a/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py +++ b/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py @@ -73,8 +73,11 @@ def finalize_data( global_dimension = dataset_model.get_global_dimension() model_dimension = dataset_model.get_model_dimension() dataset.coords["coherent_artifact_order"] = np.arange(1, self.order + 1) - dataset["coherent_artifact_concentration"] = ( - (model_dimension, "coherent_artifact_order"), + response_dimensions = (model_dimension, "coherent_artifact_order") + if dataset_model.is_index_dependent() is True: + response_dimensions = (global_dimension, *response_dimensions) + dataset["coherent_artifact_response"] = ( + response_dimensions, dataset.matrix.sel(clp_label=self.compartments()).values, ) dataset["coherent_artifact_associated_spectra"] = ( diff --git a/glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py b/glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py index 9a02ce4aa..6e0dcfb0d 100644 --- a/glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py +++ b/glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import xarray as xr from glotaran.analysis.optimize import optimize @@ -11,10 +12,14 @@ from glotaran.project import Scheme -def test_coherent_artifact(): +@pytest.mark.parametrize( + "is_index_dependent", + (True, False), +) +def test_coherent_artifact(is_index_dependent: bool): model_dict = { "initial_concentration": { - "j1": {"compartments": ["s1"], "parameters": ["2"]}, + "j1": {"compartments": ["s1"], "parameters": ["irf_center"]}, }, "megacomplex": { "mc1": {"type": "decay", "k_matrix": ["k1"]}, @@ -23,15 +28,15 @@ def test_coherent_artifact(): "k_matrix": 
{ "k1": { "matrix": { - ("s1", "s1"): "1", + ("s1", "s1"): "rate", } } }, "irf": { "irf1": { - "type": "spectral-gaussian", - "center": "2", - "width": "3", + "type": "spectral-multi-gaussian", + "center": ["irf_center"], + "width": ["irf_width"], }, }, "dataset": { @@ -42,6 +47,24 @@ def test_coherent_artifact(): }, }, } + + parameter_list = [ + ["rate", 101e-4], + ["irf_center", 10, {"vary": False, "non-negative": False}], + ["irf_width", 20, {"vary": False, "non-negative": False}], + ] + + if is_index_dependent is True: + irf_spec = model_dict["irf"]["irf1"] + irf_spec["dispersion_center"] = "irf_dispc" + irf_spec["center_dispersion"] = ["irf_disp1", "irf_disp2"] + + parameter_list += [ + ["irf_dispc", 300, {"vary": False, "non-negative": False}], + ["irf_disp1", 0.01, {"vary": False, "non-negative": False}], + ["irf_disp2", 0.001, {"vary": False, "non-negative": False}], + ] + model = Model.from_dict( model_dict.copy(), megacomplex_types={ @@ -50,23 +73,16 @@ def test_coherent_artifact(): }, ) - parameters = ParameterGroup.from_list( - [ - 101e-4, - [10, {"vary": False, "non-negative": False}], - [20, {"vary": False, "non-negative": False}], - [30, {"vary": False, "non-negative": False}], - ] - ) + parameters = ParameterGroup.from_list(parameter_list) time = np.arange(0, 50, 1.5) - spectral = np.asarray([0]) + spectral = np.asarray([200, 300, 400]) coords = {"time": time, "spectral": spectral} dataset_model = model.dataset["dataset1"].fill(model, parameters) dataset_model.overwrite_global_dimension("spectral") dataset_model.set_coordinates(coords) - matrix = calculate_matrix(dataset_model, {}) + matrix = calculate_matrix(dataset_model, {"spectral": [0, 1, 2]}) compartments = matrix.clp_labels print(compartments) @@ -77,9 +93,9 @@ def test_coherent_artifact(): assert matrix.matrix.shape == (time.size, 4) clp = xr.DataArray( - [[1, 1, 1, 1]], + np.ones((3, 4)), coords=[ - ("spectral", [0]), + ("spectral", spectral), ( "clp_label", [ @@ -102,17 +118,20 @@ def test_coherent_artifact(): print(result.optimized_parameters) for label, param in result.optimized_parameters.all(): - assert np.allclose(param.value, parameters.get(label).value, rtol=1e-1) + assert np.allclose(param.value, parameters.get(label).value, rtol=1e-8) resultdata = result.data["dataset1"] assert np.array_equal(data.time, resultdata.time) assert np.array_equal(data.spectral, resultdata.spectral) assert data.data.shape == resultdata.data.shape assert data.data.shape == resultdata.fitted_data.shape - assert np.allclose(data.data, resultdata.fitted_data, rtol=1e-2) + assert np.allclose(data.data, resultdata.fitted_data) - assert "coherent_artifact_concentration" in resultdata - assert resultdata["coherent_artifact_concentration"].shape == (time.size, 3) + assert "coherent_artifact_response" in resultdata + if is_index_dependent: + assert resultdata["coherent_artifact_response"].shape == (spectral.size, time.size, 3) + else: + assert resultdata["coherent_artifact_response"].shape == (time.size, 3) assert "coherent_artifact_associated_spectra" in resultdata - assert resultdata["coherent_artifact_associated_spectra"].shape == (1, 3) + assert resultdata["coherent_artifact_associated_spectra"].shape == (3, 3) diff --git a/glotaran/model/dataset_model.py b/glotaran/model/dataset_model.py index 2290ff3c9..443de05ae 100644 --- a/glotaran/model/dataset_model.py +++ b/glotaran/model/dataset_model.py @@ -140,7 +140,7 @@ def get_weight(self) -> np.ndarray | None: """Gets the dataset model's weight.""" return self._weight - def 
index_dependent(self) -> bool: + def is_index_dependent(self) -> bool: """Indicates if the dataset model is index dependent.""" if hasattr(self, "_index_dependent"): return self._index_dependent diff --git a/glotaran/model/dataset_model.pyi b/glotaran/model/dataset_model.pyi index c0a7b49f7..96cd84ed0 100644 --- a/glotaran/model/dataset_model.pyi +++ b/glotaran/model/dataset_model.pyi @@ -37,7 +37,7 @@ class DatasetModel: def set_data(self, dataset: xr.Dataset) -> DatasetModel: ... def get_data(self) -> np.ndarray: ... def get_weight(self) -> np.ndarray | None: ... - def index_dependent(self) -> bool: ... + def is_index_dependent(self) -> bool: ... def overwrite_index_dependent(self, index_dependent: bool): ... def has_global_model(self) -> bool: ... def set_coordinates(self, coords: dict[str, np.ndarray]): ... From 0f36985766dad6bc25ad0846024b5f18b654d86f Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Sun, 12 Sep 2021 00:05:14 +0200 Subject: [PATCH 28/29] =?UTF-8?q?=F0=9F=94=A7=20Added=20sourcery=20config?= =?UTF-8?q?=20file=20(#811)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added 'simplify-boolean-comparison' to ignore 'if something is True' prevents bugs where something is a callable, where 'if something():' would evaluate to False but 'if something:' evaluated to true since the callable object exists. --- .sourcery.yaml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 .sourcery.yaml diff --git a/.sourcery.yaml b/.sourcery.yaml new file mode 100644 index 000000000..53a242ddf --- /dev/null +++ b/.sourcery.yaml @@ -0,0 +1,18 @@ +ignore: [] + +refactor: + skip: [simplify-boolean-comparison] + +metrics: + quality_threshold: 25.0 + +clone_detection: + min_lines: 3 + min_duplicates: 2 + identical_clones_only: false + +github: + labels: [] + ignore_labels: [sourcery-ignore] + request_review: author + sourcery_branch: sourcery/{base_branch} From af7addca01fbf9ac32053f75f809d7ea7fbb9382 Mon Sep 17 00:00:00 2001 From: Joris Snellenburg Date: Thu, 16 Sep 2021 02:10:26 +0200 Subject: [PATCH 29/29] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20model.fro?= =?UTF-8?q?m=5Fdict=20to=20parse=20megacomplex=5Ftype=20from=20dict=20and?= =?UTF-8?q?=20add=20simple=5Fgenerator=20for=20testing=20(#807)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 👌 Refactor model.from_dict to use kwargs only for override; The keyword arguments megacomplex_types and default_megacomplex_type are only used for overwrites in testing * ♻️ Refactor a simple DecayModel used in testing * ✨🧪 Added plugin registry monkeypatch context managers for testing * 🔧Add rich dependency (Amazing for debugging and printing complex objects) * 🐛 Fix ThreeComponentParallel to be actually parallel * 🔧🩹 Fixed pydocstyle and darglint not picking up the testing module * 🩹Adapted pytest benchmarks to new monkeypatch_plugin_registry * 🩹 Ignore Codacy issue W0622 redefining built-in `print` when using `from rich import print` Co-authored-by: Sebastian Weigand Co-authored-by: Sourcery AI <> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 10 +- benchmark/pytest/analysis/test_problem.py | 2 + glotaran/builtin/io/yml/yml.py | 14 +- .../decay/test/test_decay_megacomplex.py | 80 +---- glotaran/model/model.py | 24 +- glotaran/testing/__init__.py | 1 + glotaran/testing/model_generators.py | 311 ++++++++++++++++++ glotaran/testing/plugin_system.py | 184 +++++++++++ 
glotaran/testing/test/__init__.py | 0 .../testing/test/test_model_generators.py | 119 +++++++ glotaran/testing/test/test_plugin_system.py | 91 +++++ requirements_dev.txt | 1 + setup.cfg | 1 + 13 files changed, 748 insertions(+), 90 deletions(-) create mode 100644 glotaran/testing/__init__.py create mode 100644 glotaran/testing/model_generators.py create mode 100644 glotaran/testing/plugin_system.py create mode 100644 glotaran/testing/test/__init__.py create mode 100644 glotaran/testing/test/test_model_generators.py create mode 100644 glotaran/testing/test/test_plugin_system.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b3ac808be..2478c3ee0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -77,8 +77,8 @@ repos: rev: 6.1.1 hooks: - id: pydocstyle - files: "^glotaran/(plugin_system|utils|deprecation)" - exclude: "docs|tests?" + files: "^glotaran/(plugin_system|utils|deprecation|testing)" + exclude: "docs|tests?/" # this is needed due to the following issue: # https://github.com/PyCQA/pydocstyle/issues/368 args: [--ignore-decorators=wrap_func_as_method] @@ -87,14 +87,14 @@ repos: rev: v1.8.0 hooks: - id: darglint - files: "^glotaran/(plugin_system|utils|deprecation)" - exclude: "docs|tests?" + files: "^glotaran/(plugin_system|utils|deprecation|testing)" + exclude: "docs|tests?/" - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.910 hooks: - id: mypy - files: "^glotaran/(plugin_system|utils|deprecation)" + files: "^glotaran/(plugin_system|utils|deprecation|testing)" exclude: "docs" additional_dependencies: [types-all] diff --git a/benchmark/pytest/analysis/test_problem.py b/benchmark/pytest/analysis/test_problem.py index 9c148818f..16aa886c3 100644 --- a/benchmark/pytest/analysis/test_problem.py +++ b/benchmark/pytest/analysis/test_problem.py @@ -13,6 +13,7 @@ from glotaran.model import megacomplex from glotaran.parameter import ParameterGroup from glotaran.project import Scheme +from glotaran.testing.plugin_system import monkeypatch_plugin_registry if TYPE_CHECKING: from glotaran.model import DatasetModel @@ -53,6 +54,7 @@ def finalize_data( pass +@monkeypatch_plugin_registry(test_megacomplex={"benchmark": BenchmarkMegacomplex}) def setup_model(index_dependent): model_dict = { "megacomplex": {"m1": {"is_index_dependent": index_dependent}}, diff --git a/glotaran/builtin/io/yml/yml.py b/glotaran/builtin/io/yml/yml.py index e45ad6cca..540c59824 100644 --- a/glotaran/builtin/io/yml/yml.py +++ b/glotaran/builtin/io/yml/yml.py @@ -16,7 +16,6 @@ from glotaran.io import save_dataset from glotaran.io import save_parameters from glotaran.model import Model -from glotaran.model import get_megacomplex from glotaran.parameter import ParameterGroup from glotaran.project import SavingOptions from glotaran.project import Scheme @@ -66,18 +65,7 @@ def load_model(self, file_name: str) -> Model: if "megacomplex" not in spec: raise ValueError("No megacomplex defined in model") - megacomplex_types = { - m["type"]: get_megacomplex(m["type"]) - for m in spec["megacomplex"].values() - if "type" in m - } - if default_megacomplex is not None: - megacomplex_types[default_megacomplex] = get_megacomplex(default_megacomplex) - del spec["default-megacomplex"] - - return Model.from_dict( - spec, megacomplex_types=megacomplex_types, default_megacomplex_type=default_megacomplex - ) + return Model.from_dict(spec, megacomplex_types=None, default_megacomplex_type=None) def load_parameters(self, file_name: str) -> ParameterGroup: diff --git 
a/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py index f788a6b90..939fb6c1b 100644 --- a/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py @@ -6,11 +6,10 @@ from glotaran.analysis.optimize import optimize from glotaran.analysis.simulation import simulate -from glotaran.builtin.megacomplexes.decay import DecayMegacomplex -from glotaran.model import Megacomplex from glotaran.model import Model from glotaran.parameter import ParameterGroup from glotaran.project import Scheme +from glotaran.testing.model_generators import SimpleModelGenerator def _create_gaussian_clp(labels, amplitudes, centers, widths, axis): @@ -28,20 +27,9 @@ class DecayModel(Model): def from_dict( cls, model_dict, - *, - megacomplex_types: dict[str, type[Megacomplex]] | None = None, - default_megacomplex_type: str | None = None, ): - defaults: dict[str, type[Megacomplex]] = { - "decay": DecayMegacomplex, - } - if megacomplex_types is not None: - defaults.update(megacomplex_types) - return super().from_dict( - model_dict, - megacomplex_types=defaults, - default_megacomplex_type=default_megacomplex_type, - ) + model_dict = {**model_dict, "default-megacomplex": "decay"} + return super().from_dict(model_dict) class OneComponentOneChannel: @@ -136,62 +124,16 @@ class OneComponentOneChannelGaussianIrf: class ThreeComponentParallel: - model = DecayModel.from_dict( - { - "initial_concentration": { - "j1": {"compartments": ["s1", "s2", "s3"], "parameters": ["j.1", "j.1", "j.1"]}, - }, - "megacomplex": { - "mc1": {"k_matrix": ["k1"]}, - }, - "k_matrix": { - "k1": { - "matrix": { - ("s2", "s1"): "kinetic.1", - ("s3", "s2"): "kinetic.2", - ("s3", "s3"): "kinetic.3", - } - } - }, - "irf": { - "irf1": { - "type": "multi-gaussian", - "center": ["irf.center"], - "width": ["irf.width"], - }, - }, - "dataset": { - "dataset1": { - "initial_concentration": "j1", - "irf": "irf1", - "megacomplex": ["mc1"], - }, - }, - } + generator = SimpleModelGenerator( + rates=[300e-3, 500e-4, 700e-5], + irf={"center": 1.3, "width": 7.8}, + k_matrix="parallel", ) + model, initial_parameters = generator.model_and_parameters + + generator.rates = [301e-3, 502e-4, 705e-5] + wanted_parameters = generator.parameters - initial_parameters = ParameterGroup.from_dict( - { - "kinetic": [ - ["1", 300e-3], - ["2", 500e-4], - ["3", 700e-5], - ], - "irf": [["center", 1.3], ["width", 7.8]], - "j": [["1", 1, {"vary": False, "non-negative": False}]], - } - ) - wanted_parameters = ParameterGroup.from_dict( - { - "kinetic": [ - ["1", 301e-3], - ["2", 502e-4], - ["3", 705e-5], - ], - "irf": [["center", 1.3], ["width", 7.8]], - "j": [["1", 1, {"vary": False, "non-negative": False}]], - } - ) time = np.arange(-10, 100, 1.5) pixel = np.arange(600, 750, 10) diff --git a/glotaran/model/model.py b/glotaran/model/model.py index 13c7cf24c..8d56c509c 100644 --- a/glotaran/model/model.py +++ b/glotaran/model/model.py @@ -2,6 +2,7 @@ from __future__ import annotations import copy +from typing import Any from typing import List from warnings import warn @@ -17,6 +18,7 @@ from glotaran.model.weight import Weight from glotaran.parameter import Parameter from glotaran.parameter import ParameterGroup +from glotaran.plugin_system.megacomplex_registration import get_megacomplex from glotaran.utils.ipython import MarkdownStr default_model_items = { @@ -56,18 +58,34 @@ def __init__( @classmethod def from_dict( cls, - 
model_dict: dict, + model_dict: dict[str, Any], *, - megacomplex_types: dict[str, type[Megacomplex]], + megacomplex_types: dict[str, type[Megacomplex]] | None = None, default_megacomplex_type: str | None = None, ) -> Model: """Creates a model from a dictionary. Parameters ---------- - model_dict : + model_dict: dict[str, Any] Dictionary containing the model. + megacomplex_types: dict[str, type[Megacomplex]] | None + Overwrite 'megacomplex_types' in ``model_dict`` for testing. + default_megacomplex_type: str | None + Overwrite 'default-megacomplex' in ``model_dict`` for testing. """ + if default_megacomplex_type is None: + default_megacomplex_type = model_dict.get("default-megacomplex") + + if megacomplex_types is None: + megacomplex_types = { + m["type"]: get_megacomplex(m["type"]) + for m in model_dict["megacomplex"].values() + if "type" in m + } + if default_megacomplex_type is not None: + megacomplex_types[default_megacomplex_type] = get_megacomplex(default_megacomplex_type) + model_dict.pop("default-megacomplex", None) model = cls( megacomplex_types=megacomplex_types, default_megacomplex_type=default_megacomplex_type diff --git a/glotaran/testing/__init__.py b/glotaran/testing/__init__.py new file mode 100644 index 000000000..a2b929333 --- /dev/null +++ b/glotaran/testing/__init__.py @@ -0,0 +1 @@ +"""Testing framework package for glotaran itself and plugins.""" diff --git a/glotaran/testing/model_generators.py b/glotaran/testing/model_generators.py new file mode 100644 index 000000000..5c59762b3 --- /dev/null +++ b/glotaran/testing/model_generators.py @@ -0,0 +1,311 @@ +"""Model generators used to generate simple models from a set of inputs.""" + +from __future__ import annotations + +from dataclasses import dataclass +from dataclasses import field +from typing import TYPE_CHECKING +from typing import Literal + +from glotaran.model import Model +from glotaran.parameter.parameter_group import ParameterGroup + +if TYPE_CHECKING: + from glotaran.utils.ipython import MarkdownStr + + +def _split_iterable_in_non_dict_and_dict_items( + input_list: list[float, dict[str, bool | float]], +) -> tuple[list[float], list[dict[str, bool | float]]]: + """Split an iterable (list) into non-dict and dict items. + + Parameters + ---------- + input_list : list[float, dict[str, bool | float]] + A list of values of type `float` and a dict with parameter options, e.g. + `[1, 2, 3, {"vary": False, "non-negative": True}]` + + Returns + ------- + tuple[list[float], list[dict[str, bool | float]]] + Split a list into non-dict (`values`) and dict items (`defaults`), + return a tuple (`values`, `defaults`) + """ + values: list = [val for val in input_list if not isinstance(val, dict)] + defaults: list = [val for val in input_list if isinstance(val, dict)] + return values, defaults + + +@dataclass +class SimpleModelGenerator: + """A minimal boilerplate model and parameters generator. 
+ + Generates a model (together with the parameters specification) based on + parameter input values assigned to the generator's attributes + """ + + rates: list[float] = field(default_factory=list) + """A list of values representing decay rates""" + k_matrix: Literal["parallel", "sequential"] | dict[tuple[str, str], str] = "parallel" + """"A `dict` with a k_matrix specification or `Literal["parallel", "sequential"]`""" + compartments: list[str] | None = None + """A list of compartment names""" + irf: dict[str, float] = field(default_factory=dict) + """A dict of items specifying an irf""" + initial_concentration: list[float] = field(default_factory=list) + """A list values representing the initial concentration""" + dispersion_coefficients: list[float] = field(default_factory=list) + """A list of values representing the dispersion coefficients""" + dispersion_center: float | None = None + """A value representing the dispersion center""" + default_megacomplex: str = "decay" + """The default_megacomplex identifier""" + # TODO: add support for a spectral model: + # shapes: list[float] = field(default_factory=list, init=False) + + @property + def valid(self) -> bool: + """Check if the generator state is valid. + + Returns + ------- + bool + Generator state obtained by calling the generated model's + `valid` function with the generated parameters as input. + """ + try: + return self.model.valid(parameters=self.parameters) + except ValueError: + return False + + def validate(self) -> str: + """Call `validate` on the generated model and return its output. + + Returns + ------- + str + A string listing problems in the generated model and parameters if any. + """ + return self.model.validate(parameters=self.parameters) + + @property + def model(self) -> Model: + """Return the generated model. + + Returns + ------- + Model + The generated model of type :class:`glotaran.model.Model`. + """ + return Model.from_dict(self.model_dict) + + @property + def model_dict(self) -> dict: + """Return a dict representation of the generated model. + + Returns + ------- + dict + A dict representation of the generated model. + """ + return self._model_dict() + + @property + def parameters(self) -> ParameterGroup: + """Return the generated parameters of type :class:`glotaran.parameter.ParameterGroup`. + + Returns + ------- + ParameterGroup + The generated parameters of type of type :class:`glotaran.parameter.ParameterGroup`. + """ + return ParameterGroup.from_dict(self.parameters_dict) + + @property + def parameters_dict(self) -> dict: + """Return a dict representation of the generated parameters. + + Returns + ------- + dict + A dict representing the generated parameters. + """ + return self._parameters_dict() + + @property + def model_and_parameters(self) -> tuple[Model, ParameterGroup]: + """Return generated model and parameters. + + Returns + ------- + tuple[Model, ParameterGroup] + A model of type :class:`glotaran.model.Model` and + and parameters of type :class:`glotaran.parameter.ParameterGroup`. + """ + return self.model, self.parameters + + @property + def _rates(self) -> tuple[list[float], list[dict[str, bool | float]]]: + """Validate input to rates, return a tuple of rates and parameter defaults. + + Returns + ------- + tuple[list[float], list[dict[str, bool | float]]] + A tuple of a list of rates and a dict containing parameter defaults + + Raises + ------ + ValueError + Raised if rates is not a list of at least one number. 
+ """ + if not isinstance(self.rates, list): + raise ValueError(f"generator.rates: must be a `list`, got: {self.rates}") + if len(self.rates) == 0: + raise ValueError("generator.rates: must be a `list` with 1 or more rates") + if not isinstance(self.rates[0], (int, float)): + raise ValueError(f"generator.rates: 1st element must be numeric, got: {self.rates[0]}") + return _split_iterable_in_non_dict_and_dict_items(self.rates) + + def _parameters_dict_items(self) -> dict: + """Return a dict with items used in constructing the parameters. + + Returns + ------- + dict + A dict with items used in constructing a parameters dict. + """ + rates, rates_defaults = self._rates + items = {"rates": rates} + if rates_defaults: + items.update({"rates_defaults": rates_defaults[0]}) + items.update({"irf": [[key, value] for key, value in self.irf.items()]}) + if self.initial_concentration: + items.update({"inputs": self.initial_concentration}) + elif self.k_matrix == "parallel": + items.update( + { + "inputs": [ + ["1", 1], + {"vary": False}, + ] + } + ) + elif self.k_matrix == "sequential": + items.update( + { + "inputs": [ + ["1", 1], + ["0", 0], + {"vary": False}, + ] + } + ) + return items + + def _model_dict_items(self) -> dict: + """Return a dict with items used in constructing the model. + + Returns + ------- + dict + A dict with items used in constructing a model dict. + """ + rates, _ = self._rates + nr = len(rates) + indices = list(range(1, 1 + nr)) + items = {"default-megacomplex": self.default_megacomplex} + if self.irf: + items.update( + { + "irf": { + "type": "multi-gaussian", + "center": ["irf.center"], + "width": ["irf.width"], + } + } + ) + if isinstance(self.k_matrix, dict): + items.update({"k_matrix": self.k_matrix}) + items.update({"input_parameters": [f"inputs.{i}" for i in indices]}) + items.update({"compartments": [f"s{i}" for i in indices]}) + # TODO: get unique compartments from user defined k_matrix + if self.k_matrix == "parallel": + items.update({"input_parameters": ["inputs.1"] * nr}) + items.update({"k_matrix": {(f"s{i}", f"s{i}"): f"rates.{i}" for i in indices}}) + elif self.k_matrix == "sequential": + items.update({"input_parameters": ["inputs.1"] + ["inputs.0"] * (nr - 1)}) + items.update( + {"k_matrix": {(f"s{i if i==nr else i+1}", f"s{i}"): f"rates.{i}" for i in indices}} + ) + if self.k_matrix in ("parallel", "sequential"): + items.update({"compartments": [f"s{i}" for i in indices]}) + return items + + def _parameters_dict(self) -> dict: + """Return a parameters dict. + + Returns + ------- + dict + A dict that can be passed to the `ParameterGroup` `from_dict` method. + """ + items = self._parameters_dict_items() + rates = items["rates"] + if "rates_defaults" in items: + rates += [items["rates_defaults"]] + result = {"rates": rates} + if items["irf"]: + result.update({"irf": items["irf"]}) + result.update({"inputs": items["inputs"]}) + return result + + def _model_dict(self) -> dict: + """Return a model dict. + + Returns + ------- + dict + A dict that can be passed to the `Model` `from_dict` method. 
+ """ + items = self._model_dict_items() + result = {"default-megacomplex": items["default-megacomplex"]} + result.update( + { + "initial_concentration": { + "j1": { + "compartments": items["compartments"], + "parameters": items["input_parameters"], + }, + }, + "megacomplex": { + "mc1": {"k_matrix": ["k1"]}, + }, + "k_matrix": {"k1": {"matrix": items["k_matrix"]}}, + "dataset": { + "dataset1": { + "initial_concentration": "j1", + "megacomplex": ["mc1"], + }, + }, + } + ) + if "irf" in items: + result["dataset"]["dataset1"].update({"irf": "irf1"}) + result.update( + { + "irf": { + "irf1": items["irf"], + } + } + ) + return result + + def markdown(self) -> MarkdownStr: + """Return a markdown string representation of the generated model and parameters. + + Returns + ------- + MarkdownStr + A markdown string + """ + return self.model.markdown(parameters=self.parameters) diff --git a/glotaran/testing/plugin_system.py b/glotaran/testing/plugin_system.py new file mode 100644 index 000000000..1487eb1cd --- /dev/null +++ b/glotaran/testing/plugin_system.py @@ -0,0 +1,184 @@ +"""Mock functionality for the plugin system.""" +from __future__ import annotations + +from contextlib import ExitStack +from contextlib import contextmanager +from typing import TYPE_CHECKING +from unittest import mock + +from glotaran.plugin_system.base_registry import __PluginRegistry + +if TYPE_CHECKING: + from typing import Generator + from typing import MutableMapping + + from glotaran.io.interface import DataIoInterface + from glotaran.io.interface import ProjectIoInterface + from glotaran.model.megacomplex import Megacomplex + from glotaran.plugin_system.base_registry import _PluginType + + +@contextmanager +def _monkeypatch_plugin_registry( + register_name: str, + test_registry: MutableMapping[str, _PluginType] | None = None, + create_new_registry: bool = False, +) -> Generator[None, None, None]: + """Contextmanager to monkeypatch any Pluginregistry with name ``register_name``. + + Parameters + ---------- + register_name : str + Name of the register which should be patched. + test_registry : MutableMapping[str, _PluginType] + Registry to to update or replace the ``register_name`` registry with. + , by default None + create_new_registry : bool + Whether to update the actual registry or create a new one from ``test_registry`` + , by default False + + Yields + ------ + Generator[None, None, None] + Just to keep the context alive. + + See Also + -------- + monkeypatch_plugin_registry_megacomplex + monkeypatch_plugin_registry_data_io + monkeypatch_plugin_registry_project_io + """ + if test_registry is not None: + initila_plugins = ( + __PluginRegistry.__dict__[register_name] if not create_new_registry else {} + ) + + with mock.patch.object( + __PluginRegistry, register_name, {**initila_plugins, **test_registry} + ): + yield + else: + yield + + +@contextmanager +def monkeypatch_plugin_registry_megacomplex( + test_megacomplex: MutableMapping[str, type[Megacomplex]] | None = None, + create_new_registry: bool = False, +) -> Generator[None, None, None]: + """Monkeypatch the :class:`Megacomplex` registry. + + Parameters + ---------- + test_megacomplex : MutableMapping[str, type[Megacomplex]], optional + Registry to to update or replace the ``Megacomplex`` registry with. + , by default None + create_new_registry : bool + Whether to update the actual registry or create a new one from ``test_megacomplex`` + , by default False + + Yields + ------ + Generator[None, None, None] + Just to keep the context alive. 
+ """ + with _monkeypatch_plugin_registry("megacomplex", test_megacomplex, create_new_registry): + yield + + +@contextmanager +def monkeypatch_plugin_registry_data_io( + test_data_io: MutableMapping[str, DataIoInterface] | None = None, + create_new_registry: bool = False, +) -> Generator[None, None, None]: + """Monkeypatch the :class:`DataIoInterface` registry. + + Parameters + ---------- + test_data_io : MutableMapping[str, DataIoInterface], optional + Registry to to update or replace the ``DataIoInterface`` registry with. + , by default None + create_new_registry : bool + Whether to update the actual registry or create a new one from ``test_data_io`` + , by default False + + Yields + ------ + Generator[None, None, None] + Just to keep the context alive. + """ + with _monkeypatch_plugin_registry("data_io", test_data_io, create_new_registry): + yield + + +@contextmanager +def monkeypatch_plugin_registry_project_io( + test_project_io: MutableMapping[str, ProjectIoInterface] | None = None, + create_new_registry: bool = False, +) -> Generator[None, None, None]: + """Monkeypatch the :class:`ProjectIoInterface` registry. + + Parameters + ---------- + test_project_io : MutableMapping[str, ProjectIoInterface], optional + Registry to to update or replace the ``ProjectIoInterface`` registry with. + , by default None + create_new_registry : bool + Whether to update the actual registry or create a new one from ``test_data_io`` + , by default False + + Yields + ------ + Generator[None, None, None] + Just to keep the context alive. + """ + with _monkeypatch_plugin_registry("project_io", test_project_io, create_new_registry): + yield + + +@contextmanager +def monkeypatch_plugin_registry( + *, + test_megacomplex: MutableMapping[str, type[Megacomplex]] | None = None, + test_data_io: MutableMapping[str, DataIoInterface] | None = None, + test_project_io: MutableMapping[str, ProjectIoInterface] | None = None, + create_new_registry: bool = False, +) -> Generator[None, None, None]: + """Contextmanager to monkeypatch multiple plugin registries at once. + + Parameters + ---------- + test_megacomplex : MutableMapping[str, type[Megacomplex]], optional + Registry to to update or replace the ``Megacomplex`` registry with. + , by default None + test_data_io : MutableMapping[str, DataIoInterface], optional + Registry to to update or replace the ``DataIoInterface`` registry with. + , by default None + test_project_io : MutableMapping[str, ProjectIoInterface], optional + Registry to to update or replace the ``ProjectIoInterface`` registry with. + , by default None + create_new_registry : bool + Whether to update the actual registry or create a new one from the arguments. 
+ , by default False + + Yields + ------ + Generator[None, None, None] + Just keeps all context manager alive + + See Also + -------- + monkeypatch_plugin_registry_megacomplex + monkeypatch_plugin_registry_data_io + monkeypatch_plugin_registry_project_io + """ + context_managers = [ + monkeypatch_plugin_registry_megacomplex(test_megacomplex, create_new_registry), + monkeypatch_plugin_registry_data_io(test_data_io, create_new_registry), + monkeypatch_plugin_registry_project_io(test_project_io, create_new_registry), + ] + + with ExitStack() as stack: + for context_manager in context_managers: + stack.enter_context(context_manager) + yield diff --git a/glotaran/testing/test/__init__.py b/glotaran/testing/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/glotaran/testing/test/test_model_generators.py b/glotaran/testing/test/test_model_generators.py new file mode 100644 index 000000000..bd287f5cf --- /dev/null +++ b/glotaran/testing/test/test_model_generators.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from copy import deepcopy + +import pytest +from rich import pretty +from rich import print # pylint: disable=W0622 + +from glotaran.model import Model +from glotaran.parameter import ParameterGroup +from glotaran.testing.model_generators import SimpleModelGenerator + +pretty.install() + + +REF_PARAMETER_DICT = { + "rates": [ + ["1", 501e-3], + ["2", 202e-4], + ["3", 105e-5], + {"non-negative": True}, + ], + "irf": [["center", 1.3], ["width", 7.8]], + "inputs": [ + ["1", 1], + ["0", 0], + {"vary": False}, + ], +} + +REF_MODEL_DICT = { + "default-megacomplex": "decay", + "initial_concentration": { + "j1": { + "compartments": ["s1", "s2", "s3"], + "parameters": ["inputs.1", "inputs.0", "inputs.0"], + }, + }, + "megacomplex": { + "mc1": {"k_matrix": ["k1"]}, + }, + "k_matrix": { + "k1": { + "matrix": { + ("s2", "s1"): "rates.1", + ("s3", "s2"): "rates.2", + ("s3", "s3"): "rates.3", + } + } + }, + "irf": { + "irf1": { + "type": "multi-gaussian", + "center": ["irf.center"], + "width": ["irf.width"], + }, + }, + "dataset": { + "dataset1": { + "initial_concentration": "j1", + "irf": "irf1", + "megacomplex": ["mc1"], + }, + }, +} + + +def simple_diff_between_string(string1, string2): + return "".join(c2 for c1, c2 in zip(string1, string2) if c1 != c2) + + +def test_three_component_sequential_model(): + ref_model = Model.from_dict(deepcopy(REF_MODEL_DICT)) + ref_parameters = ParameterGroup.from_dict(deepcopy(REF_PARAMETER_DICT)) + generator = SimpleModelGenerator( + rates=[501e-3, 202e-4, 105e-5, {"non-negative": True}], + irf={"center": 1.3, "width": 7.8}, + k_matrix="sequential", + ) + for key, _ in REF_PARAMETER_DICT.items(): + assert key in generator.parameters_dict + # TODO: check contents + + model, parameters = generator.model_and_parameters + assert str(ref_model) == str(model), print( + simple_diff_between_string(str(model), str(ref_model)) + ) + assert str(ref_parameters) == str(parameters), print( + simple_diff_between_string(str(parameters), str(ref_parameters)) + ) + + +def test_only_rates_no_irf(): + generator = SimpleModelGenerator(rates=[0.1, 0.02, 0.003]) + assert "irf" not in generator.model_dict.keys() + + +def test_no_rates(): + generator = SimpleModelGenerator() + assert generator.valid is False + + +def test_one_rate(): + generator = SimpleModelGenerator([1]) + assert generator.valid is True + assert "is valid" in generator.validate() + + +def test_rates_not_a_list(): + generator = SimpleModelGenerator(1) + assert generator.valid is False + 
with pytest.raises(ValueError): + print(generator.validate()) + + +def test_set_rates_delayed(): + generator = SimpleModelGenerator() + generator.rates = [1, 2, 3] + assert generator.valid is True diff --git a/glotaran/testing/test/test_plugin_system.py b/glotaran/testing/test/test_plugin_system.py new file mode 100644 index 000000000..df1b2d9b7 --- /dev/null +++ b/glotaran/testing/test/test_plugin_system.py @@ -0,0 +1,91 @@ +import pytest + +from glotaran.io import DataIoInterface +from glotaran.io import ProjectIoInterface +from glotaran.model import Megacomplex +from glotaran.model import megacomplex +from glotaran.plugin_system.data_io_registration import known_data_formats +from glotaran.plugin_system.megacomplex_registration import known_megacomplex_names +from glotaran.plugin_system.project_io_registration import known_project_formats +from glotaran.testing.plugin_system import monkeypatch_plugin_registry +from glotaran.testing.plugin_system import monkeypatch_plugin_registry_data_io +from glotaran.testing.plugin_system import monkeypatch_plugin_registry_megacomplex +from glotaran.testing.plugin_system import monkeypatch_plugin_registry_project_io + + +@megacomplex(dimension="test") +class DummyMegacomplex(Megacomplex): + pass + + +class DummyDataIo(DataIoInterface): + pass + + +class DummyProjectIo(ProjectIoInterface): + pass + + +def test_monkeypatch_megacomplexes(): + """Megacomplex only added to registry while context is entered.""" + with monkeypatch_plugin_registry_megacomplex(test_megacomplex={"test_mc": DummyMegacomplex}): + assert "test_mc" in known_megacomplex_names() + + assert "test_mc" not in known_megacomplex_names() + with monkeypatch_plugin_registry(test_megacomplex={"test_full": DummyMegacomplex}): + assert "test_full" in known_megacomplex_names() + + assert "test_full" not in known_megacomplex_names() + + +def test_monkeypatch_data_io(): + """DataIoInterface only added to registry while context is entered.""" + with monkeypatch_plugin_registry_data_io( + test_data_io={"test_dio": DummyDataIo(format_name="test")} + ): + assert "test_dio" in known_data_formats() + + assert "test_mc" not in known_data_formats() + + with monkeypatch_plugin_registry(test_data_io={"test_full": DummyDataIo(format_name="test")}): + assert "test_full" in known_data_formats() + + assert "test_full" not in known_data_formats() + + +def test_monkeypatch_project_io(): + """ProjectIoInterface only added to registry while context is entered.""" + with monkeypatch_plugin_registry_project_io( + test_project_io={"test_pio": DummyProjectIo(format_name="test")} + ): + assert "test_pio" in known_project_formats() + + assert "test_pio" not in known_megacomplex_names() + with monkeypatch_plugin_registry( + test_project_io={"test_full": DummyProjectIo(format_name="test")} + ): + assert "test_full" in known_project_formats() + + assert "test_full" not in known_project_formats() + + +@pytest.mark.parametrize("create_new_registry", (True, False)) +def test_monkeypatch_plugin_registry_full(create_new_registry: bool): + """Create a completely new registry.""" + + assert "decay" in known_megacomplex_names() + assert "yml" in known_project_formats() + assert "sdt" in known_data_formats() + + with monkeypatch_plugin_registry( + test_megacomplex={"test_mc": DummyMegacomplex}, + test_project_io={"test_pio": DummyProjectIo(format_name="test")}, + test_data_io={"test_dio": DummyDataIo(format_name="test")}, + create_new_registry=create_new_registry, + ): + assert "test_mc" in known_megacomplex_names() + assert 
"test_pio" in known_project_formats() + assert "test_dio" in known_data_formats() + assert ("decay" not in known_megacomplex_names()) is create_new_registry + assert ("yml" not in known_project_formats()) is create_new_registry + assert ("sdt" not in known_data_formats()) is create_new_registry diff --git a/requirements_dev.txt b/requirements_dev.txt index e63b6b521..6df755111 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -7,6 +7,7 @@ asteval==0.9.25 numpy==1.21.1 scipy==1.7.0 click==8.0.1 +rich==10.9.0 numba==0.53.1 pandas==1.3.1 pyyaml==5.4.1 diff --git a/setup.cfg b/setup.cfg index 977820588..022bd304e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,7 @@ install_requires = numpy>=1.20.0 pandas>=0.25.2 pyyaml>=5.2 + rich>=10.9.0 scipy>=1.3.2 sdtfile>=2020.8.3 setuptools>=41.2