diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5cb8f946a..9470afe2c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -119,6 +119,7 @@ repos: - id: interrogate args: [-vv, --config=pyproject.toml, glotaran] pass_filenames: false + additional_dependencies: [click<8] - repo: https://github.com/myint/rstcheck rev: "3f92957478422df87bd730abde66f089cc1ee19b" diff --git a/changelog.md b/changelog.md index 8937d3cbd..75d5488f2 100644 --- a/changelog.md +++ b/changelog.md @@ -7,6 +7,7 @@ - ✨ Python 3.10 support (#977) - ✨ Add simple decay megacomplexes (#860) - ✨ Feature: Generators (#866) +- ✨ Add clp guidance megacomplex (#1029) ### 👌 Minor Improvements: diff --git a/glotaran/analysis/optimization_group.py b/glotaran/analysis/optimization_group.py index 68828c821..0b1748965 100644 --- a/glotaran/analysis/optimization_group.py +++ b/glotaran/analysis/optimization_group.py @@ -339,6 +339,8 @@ def create_result_dataset( dataset_model = self.dataset_models[label] global_dimension = dataset_model.get_global_dimension() model_dimension = dataset_model.get_model_dimension() + dataset.attrs["global_dimension"] = global_dimension + dataset.attrs["model_dimension"] = model_dimension if copy: dataset = dataset.copy() if dataset_model.is_index_dependent(): diff --git a/glotaran/builtin/io/ascii/wavelength_time_explicit_file.py b/glotaran/builtin/io/ascii/wavelength_time_explicit_file.py index 9bceecde5..a396c36dd 100644 --- a/glotaran/builtin/io/ascii/wavelength_time_explicit_file.py +++ b/glotaran/builtin/io/ascii/wavelength_time_explicit_file.py @@ -2,8 +2,8 @@ import os.path import re -import warnings from enum import Enum +from warnings import warn import numpy as np import pandas as pd @@ -13,8 +13,6 @@ from glotaran.io import register_data_io from glotaran.io.prepare_dataset import prepare_time_trace_dataset -# from glotaran.io.reader import file_reader - class DataFileType(Enum): time_explicit = "Time explicit" @@ -27,7 +25,7 @@ class ExplicitFile: """ # TODO: implement time_intervals - def __init__(self, filepath: str = None, dataset: xr.DataArray = None): + def __init__(self, filepath: str | None = None, dataset: xr.DataArray | None = None): self._file_data_format = None self._observations = [] # TODO: choose name: data_points, observations, data self._times = [] @@ -76,22 +74,19 @@ def write( if os.path.isfile(self._file) and not overwrite: raise FileExistsError(f"File already exist:\n{self._file}") - comment = self._comment + " " + comment + comment = f"{self._comment} {comment}" - comments = "# Filename: " + str(self._file) + "\n" + " ".join(comment.splitlines()) + "\n" + comments = f"# Filename: {str(self._file)}\n{' '.join(comment.splitlines())}\n" if file_format == DataFileType.wavelength_explicit: wav = "\t".join(repr(num) for num in self._spectral_indices) header = ( - comments + "Wavelength explicit\nIntervalnr {}" - "".format(len(self._spectral_indices)) + "\n" + wav + f"{comments}Wavelength explicit\nIntervalnr {len(self._spectral_indices)}\n{wav}" ) raw_data = np.vstack((self._times.T, self._observations)).T elif file_format == DataFileType.time_explicit: tim = "\t".join(repr(num) for num in self._times) - header = ( - comments + "Time explicit\nIntervalnr {}" "".format(len(self._times)) + "\n" + tim - ) + header = f"{comments}Time explicit\nIntervalnr {len(self._times)}\n{tim}" raw_data = np.vstack((self._spectral_indices.T, self._observations.T)).T else: raise NotImplementedError @@ -109,7 +104,7 @@ def write( def read(self, prepare: 
bool = True):
         if not os.path.isfile(self._file):
-            raise Exception("File does not exist.")
+            raise FileNotFoundError("File does not exist.")
         with open(self._file) as f:
             f.readline()  # The first two lines are comments
             f.readline()
@@ -221,7 +216,7 @@ def get_interval_number(line):
     try:
         interval_number = int(interval_number)
     except ValueError:
-        warnings.warn(f"No interval number found in line:\n{line}")
+        warn(f"No interval number found in line:\n{line}")
         interval_number = None
     return interval_number
@@ -242,7 +237,7 @@ def get_data_file_format(line):
 # @file_reader(extension="ascii", name="Wavelength-/Time-Explicit ASCII")
 @register_data_io("ascii")
 class AsciiDataIo(DataIoInterface):
-    def load_dataset(self, file_name: str) -> xr.Dataset | xr.DataArray:
+    def load_dataset(self, file_name: str, *, prepare: bool = True) -> xr.Dataset | xr.DataArray:
         """Reads an ascii file in wavelength- or time-explicit format.
 
         See [1]_ for documentation of this format.
@@ -272,17 +267,28 @@ def load_dataset(self, file_name: str) -> xr.Dataset | xr.DataArray:
             else TimeExplicitFile(file_name)
         )
 
-        return data_file.read(prepare=True)
+        return data_file.read(prepare=prepare)
 
     def save_dataset(
         self,
-        dataset: xr.DataArray,
+        dataset: xr.DataArray | xr.Dataset,
         file_name: str,
         *,
         comment: str = "",
         file_format: DataFileType = DataFileType.time_explicit,
         number_format: str = "%.10e",
     ):
+        if isinstance(dataset, xr.Dataset) and "data" in dataset:
+            dataset = dataset.data
+            warn(
+                UserWarning(
+                    "Saving the 'data' attribute of 'dataset' as a fallback. "
+                    "Result saving for ascii format only supports xarray.DataArray format, "
+                    "please pass a xarray.DataArray instead of a xarray.Dataset "
+                    "(e.g. dataset.data)."
+                ),
+                stacklevel=4,
+            )
         data_file = (
             TimeExplicitFile(filepath=file_name, dataset=dataset)
             if file_format is DataFileType.time_explicit
diff --git a/glotaran/builtin/megacomplexes/clp_guide/__init__.py b/glotaran/builtin/megacomplexes/clp_guide/__init__.py
new file mode 100644
index 000000000..3ad4e5c13
--- /dev/null
+++ b/glotaran/builtin/megacomplexes/clp_guide/__init__.py
@@ -0,0 +1 @@
+from glotaran.builtin.megacomplexes.clp_guide.clp_guide_megacomplex import ClpGuideMegacomplex
diff --git a/glotaran/builtin/megacomplexes/clp_guide/clp_guide_megacomplex.py b/glotaran/builtin/megacomplexes/clp_guide/clp_guide_megacomplex.py
new file mode 100644
index 000000000..54f171c11
--- /dev/null
+++ b/glotaran/builtin/megacomplexes/clp_guide/clp_guide_megacomplex.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+import numpy as np
+import xarray as xr
+
+from glotaran.model import DatasetModel
+from glotaran.model import Megacomplex
+from glotaran.model import megacomplex
+
+
+@megacomplex(exclusive=True, register_as="clp-guide", properties={"target": str})
+class ClpGuideMegacomplex(Megacomplex):
+    def calculate_matrix(
+        self,
+        dataset_model: DatasetModel,
+        indices: dict[str, int],
+        **kwargs,
+    ):
+        clp_label = [self.target]
+        matrix = np.ones((1, 1), dtype=np.float64)
+        return clp_label, matrix
+
+    def index_dependent(self, dataset_model: DatasetModel) -> bool:
+        return False
+
+    def finalize_data(
+        self,
+        dataset_model: DatasetModel,
+        dataset: xr.Dataset,
+        is_full_model: bool = False,
+        as_global: bool = False,
+    ):
+        pass
diff --git a/glotaran/builtin/megacomplexes/clp_guide/test/test_clp_guide_megacomplex.py b/glotaran/builtin/megacomplexes/clp_guide/test/test_clp_guide_megacomplex.py
new file mode 100644
index 000000000..85a78588f
--- /dev/null
+++ 
b/glotaran/builtin/megacomplexes/clp_guide/test/test_clp_guide_megacomplex.py @@ -0,0 +1,64 @@ +import numpy as np + +from glotaran.analysis.optimize import optimize +from glotaran.analysis.simulation import simulate +from glotaran.builtin.megacomplexes.clp_guide import ClpGuideMegacomplex +from glotaran.builtin.megacomplexes.decay import DecaySequentialMegacomplex +from glotaran.builtin.megacomplexes.decay.test.test_decay_megacomplex import create_gaussian_clp +from glotaran.model import Model +from glotaran.parameter import ParameterGroup +from glotaran.project import Scheme + + +def test_clp_guide(): + + model = Model.from_dict( + { + "dataset_groups": {"default": {"link_clp": True}}, + "megacomplex": { + "mc1": { + "type": "decay-sequential", + "compartments": ["s1", "s2"], + "rates": ["1", "2"], + }, + "mc2": {"type": "clp-guide", "dimension": "time", "target": "s1"}, + }, + "dataset": { + "dataset1": {"megacomplex": ["mc1"]}, + "dataset2": {"megacomplex": ["mc2"]}, + }, + }, + megacomplex_types={ + "decay-sequential": DecaySequentialMegacomplex, + "clp-guide": ClpGuideMegacomplex, + }, + ) + + initial_parameters = ParameterGroup.from_list( + [101e-5, 501e-4, [1, {"vary": False, "non-negative": False}]] + ) + wanted_parameters = ParameterGroup.from_list( + [101e-4, 501e-3, [1, {"vary": False, "non-negative": False}]] + ) + + time = np.arange(0, 50, 1.5) + pixel = np.arange(600, 750, 5) + axis = {"time": time, "pixel": pixel} + + clp = create_gaussian_clp(["s1", "s2"], [7, 30], [620, 720], [10, 50], pixel) + + dataset1 = simulate(model, "dataset1", wanted_parameters, axis, clp) + dataset2 = clp.sel(clp_label=["s1"]).rename(clp_label="time") + data = {"dataset1": dataset1, "dataset2": dataset2} + + scheme = Scheme( + model=model, + parameters=initial_parameters, + data=data, + maximum_number_function_evaluations=20, + ) + result = optimize(scheme) + print(result.optimized_parameters) + + for label, param in result.optimized_parameters.all(): + assert np.allclose(param.value, wanted_parameters.get(label).value, rtol=1e-1) diff --git a/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py index 64b7c93a3..6412f5a5f 100644 --- a/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py @@ -11,7 +11,7 @@ from glotaran.project import Scheme -def _create_gaussian_clp(labels, amplitudes, centers, widths, axis): +def create_gaussian_clp(labels, amplitudes, centers, widths, axis): return xr.DataArray( [ amplitudes[i] * np.exp(-np.log(2) * np.square(2 * (axis - centers[i]) / widths[i])) @@ -179,9 +179,7 @@ class ThreeComponentParallel: axis = {"time": time, "pixel": pixel} - clp = _create_gaussian_clp( - ["s1", "s2", "s3"], [7, 3, 30], [620, 670, 720], [10, 30, 50], pixel - ) + clp = create_gaussian_clp(["s1", "s2", "s3"], [7, 3, 30], [620, 670, 720], [10, 30, 50], pixel) class ThreeComponentSequential: @@ -240,9 +238,7 @@ class ThreeComponentSequential: pixel = np.arange(600, 750, 10) axis = {"time": time, "pixel": pixel} - clp = _create_gaussian_clp( - ["s1", "s2", "s3"], [7, 3, 30], [620, 670, 720], [10, 30, 50], pixel - ) + clp = create_gaussian_clp(["s1", "s2", "s3"], [7, 3, 30], [620, 670, 720], [10, 30, 50], pixel) @pytest.mark.parametrize( diff --git a/glotaran/builtin/megacomplexes/decay/util.py b/glotaran/builtin/megacomplexes/decay/util.py index d96addbe5..f1a54afcf 100644 --- 
a/glotaran/builtin/megacomplexes/decay/util.py +++ b/glotaran/builtin/megacomplexes/decay/util.py @@ -11,6 +11,18 @@ def index_dependent(dataset_model: DatasetModel) -> bool: + """Determine if a dataset_model is index dependent. + + Parameters + ---------- + dataset_model : DatasetModel + A dataset model instance. + + Returns + ------- + bool + Returns True if the dataset_model has an IRF that is index dependent (e.g. has dispersion). + """ return ( isinstance(dataset_model.irf, IrfMultiGaussian) and dataset_model.irf.is_index_dependent() ) diff --git a/glotaran/model/dataset_model.py b/glotaran/model/dataset_model.py index 188fccdf6..1fe82ed79 100644 --- a/glotaran/model/dataset_model.py +++ b/glotaran/model/dataset_model.py @@ -1,6 +1,8 @@ """The DatasetModel class.""" + from __future__ import annotations +import contextlib from collections import Counter from typing import TYPE_CHECKING @@ -39,16 +41,16 @@ class DatasetModel: def iterate_megacomplexes( self, - ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: - """Iterates of der dataset model's megacomplexes.""" + ) -> Generator[tuple[Parameter | str | None, Megacomplex | str], None, None]: + """Iterates the dataset model's megacomplexes.""" for i, megacomplex in enumerate(self.megacomplex): scale = self.megacomplex_scale[i] if self.megacomplex_scale is not None else None yield scale, megacomplex def iterate_global_megacomplexes( self, - ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: - """Iterates of der dataset model's global megacomplexes.""" + ) -> Generator[tuple[Parameter | str | None, Megacomplex | str], None, None]: + """Iterates the dataset model's global megacomplexes.""" for i, megacomplex in enumerate(self.global_megacomplex): scale = ( self.global_megacomplex_scale[i] @@ -172,7 +174,7 @@ def get_global_axis(self) -> np.ndarray: @model_item_validator(False) def ensure_unique_megacomplexes(self, model: Model) -> list[str]: - """Ensure that unique megacomplexes Are only used once per dataset. + """Ensure that unique megacomplexes are only used once per dataset. Parameters ---------- @@ -184,20 +186,67 @@ def ensure_unique_megacomplexes(self, model: Model) -> list[str]: list[str] Error messages to be shown when the model gets validated. 
""" - glotaran_unique_megacomplex_types = [] - - for megacomplex_name in self.megacomplex: - try: - megacomplex_instance = model.megacomplex[megacomplex_name] - if type(megacomplex_instance).glotaran_unique() is True: - type_name = megacomplex_instance.type or megacomplex_instance.name - glotaran_unique_megacomplex_types.append(type_name) - except KeyError: - pass - - return [ - f"Multiple instances of unique megacomplex type {type_name!r} " - f"in dataset {self.label!r}" - for type_name, count in Counter(glotaran_unique_megacomplex_types).most_common() - if count > 1 - ] + errors = [] + + def get_unique_errors(megacomplexes: list[str], is_global: bool) -> list[str]: + unique_types = [] + for megacomplex_name in megacomplexes: + with contextlib.suppress(KeyError): + megacomplex_instance = model.megacomplex[megacomplex_name] + if type(megacomplex_instance).glotaran_unique(): + type_name = megacomplex_instance.type or megacomplex_instance.name + unique_types.append(type_name) + this_errors = [ + f"Multiple instances of unique{' global ' if is_global else ' '}" + f"megacomplex type {type_name!r} in dataset {self.label!r}" + for type_name, count in Counter(unique_types).most_common() + if count > 1 + ] + + return this_errors + + if self.megacomplex: + errors += get_unique_errors(self.megacomplex, False) + if self.global_megacomplex: + errors += get_unique_errors(self.global_megacomplex, True) + + return errors + + @model_item_validator(False) + def ensure_exclusive_megacomplexes(self, model: Model) -> list[str]: + """Ensure that exclusive megacomplexes are the only megacomplex in the dataset model. + + Parameters + ---------- + model : Model + Model object using this dataset model. + + Returns + ------- + list[str] + Error messages to be shown when the model gets validated. + """ + + errors = [] + + def get_exclusive_errors(megacomplexes: list[str]) -> list[str]: + with contextlib.suppress(StopIteration): + exclusive_megacomplex = next( + model.megacomplex[label] + for label in megacomplexes + if label in model.megacomplex + and type(model.megacomplex[label]).glotaran_exclusive() + ) + if len(self.megacomplex) != 1: + return [ + f"Megacomplex '{type(exclusive_megacomplex)}' is exclusive and cannot be " + f"combined with other megacomplex in dataset model '{self.label}'." + ] + return [] + + if self.megacomplex: + errors += get_exclusive_errors(self.megacomplex) + if self.global_megacomplex: + errors += get_exclusive_errors(self.global_megacomplex) + + return errors diff --git a/glotaran/model/dataset_model.pyi b/glotaran/model/dataset_model.pyi index 96cd84ed0..475d90734 100644 --- a/glotaran/model/dataset_model.pyi +++ b/glotaran/model/dataset_model.pyi @@ -45,3 +45,4 @@ class DatasetModel: def get_model_axis(self) -> np.ndarray: ... def get_global_axis(self) -> np.ndarray: ... def ensure_unique_megacomplexes(self, model: Model) -> list[str]: ... + def ensure_exclusive_megacomplexes(self, model: Model) -> list[str]: ... 
diff --git a/glotaran/model/megacomplex.py b/glotaran/model/megacomplex.py index 5b6f0e103..d5cb9c988 100644 --- a/glotaran/model/megacomplex.py +++ b/glotaran/model/megacomplex.py @@ -37,6 +37,7 @@ def megacomplex( dataset_model_items: dict[str, dict[str, Any]] = None, dataset_properties: Any | dict[str, dict[str, Any]] = None, unique: bool = False, + exclusive: bool = False, register_as: str | None = None, ): """The `@megacomplex` decorator is intended to be used on subclasses of @@ -67,6 +68,7 @@ def decorator(cls): setattr(cls, "_glotaran_megacomplex_dataset_model_items", dataset_model_items) setattr(cls, "_glotaran_megacomplex_dataset_properties", dataset_properties) setattr(cls, "_glotaran_megacomplex_unique", unique) + setattr(cls, "_glotaran_megacomplex_exclusive", exclusive) megacomplex_type = model_item(properties=properties, has_type=True)(cls) @@ -140,3 +142,7 @@ def glotaran_dataset_properties(cls) -> str: @classmethod def glotaran_unique(cls) -> bool: return cls._glotaran_megacomplex_unique + + @classmethod + def glotaran_exclusive(cls) -> bool: + return cls._glotaran_megacomplex_exclusive diff --git a/glotaran/model/model.py b/glotaran/model/model.py index ab273d913..a43b37b1c 100644 --- a/glotaran/model/model.py +++ b/glotaran/model/model.py @@ -216,11 +216,13 @@ def _add_dataset_property(self, property_name: str, dataset_property: dict[str, if isinstance(self._dataset_properties, dict) else self._dataset_properties[property_name] ) + new_type = ( dataset_property["type"] if isinstance(dataset_property, dict) else dataset_property ) + if known_type != new_type: raise ModelError( f"Cannot add dataset property of type {property_name} as it was " @@ -383,7 +385,7 @@ def problem_list(self, parameters: ParameterGroup = None) -> list[str]: for item in items: problems += item.validate(self, parameters=parameters) else: - for _, item in items.items(): + for item in items.values(): problems += item.validate(self, parameters=parameters) return problems diff --git a/glotaran/model/test/test_model.py b/glotaran/model/test/test_model.py index a20ec8c3e..b1970cc2b 100644 --- a/glotaran/model/test/test_model.py +++ b/glotaran/model/test/test_model.py @@ -94,11 +94,16 @@ class MockMegacomplex6(Megacomplex): pass -@megacomplex(dimension="model", model_items={"test_item_simple": MockItemSimple}) +@megacomplex(dimension="model", exclusive=True) class MockMegacomplex7(Megacomplex): pass +@megacomplex(dimension="model", model_items={"test_item_simple": MockItemSimple}) +class MockMegacomplex8(Megacomplex): + pass + + @pytest.fixture def test_model_dict(): model_dict = { @@ -169,7 +174,12 @@ def test_model(test_model_dict): @pytest.fixture def model_error(): model_dict = { - "megacomplex": {"m1": {}, "m2": {"type": "type2"}, "m3": {"type": "type2"}}, + "megacomplex": { + "m1": {}, + "m2": {"type": "type2"}, + "m3": {"type": "type2"}, + "m4": {"type": "type3"}, + }, "test_item1": { "t1": { "param": "fool", @@ -184,7 +194,7 @@ def model_error(): "scale": "scale_1", }, "dataset2": { - "megacomplex": ["mrX"], + "megacomplex": ["mrX", "m4"], "scale": "scale_3", }, "dataset3": { @@ -197,6 +207,7 @@ def model_error(): megacomplex_types={ "type1": MockMegacomplex1, "type2": MockMegacomplex6, + "type3": MockMegacomplex7, }, ) @@ -321,9 +332,9 @@ def test_model_validity(test_model: Model, model_error: Model, parameter: Parame print(model_error.problem_list()) print(model_error.problem_list(parameter)) assert not model_error.valid() - assert len(model_error.problem_list()) == 5 + assert 
len(model_error.problem_list()) == 6
     assert not model_error.valid(parameter)
-    assert len(model_error.problem_list(parameter)) == 9
+    assert len(model_error.problem_list(parameter)) == 10
 
 
 def test_items(test_model: Model):
@@ -407,7 +418,7 @@ def test_fill(test_model: Model, parameter: ParameterGroup):
 
 def test_model_as_dict():
     model_dict = {
-        "default_megacomplex": "type7",
+        "default_megacomplex": "type8",
         "megacomplex": {
             "m1": {"test_item_simple": "t2", "dimension": "model"},
         },
@@ -433,7 +444,7 @@ def test_model_as_dict():
     model = Model.from_dict(
         model_dict,
         megacomplex_types={
-            "type7": MockMegacomplex7,
+            "type8": MockMegacomplex8,
         },
     )
     as_model_dict = model.as_dict()
diff --git a/glotaran/project/result.py b/glotaran/project/result.py
index 3cdaea408..5d627b9ae 100644
--- a/glotaran/project/result.py
+++ b/glotaran/project/result.py
@@ -27,6 +27,7 @@
 from glotaran.project.dataclass_helpers import init_file_loadable_fields
 from glotaran.project.scheme import Scheme
 from glotaran.utils.io import DatasetMapping
+from glotaran.utils.io import create_clp_guide_dataset
 from glotaran.utils.ipython import MarkdownStr
 
 if TYPE_CHECKING:
@@ -310,6 +311,46 @@ def verify(self) -> bool:
 
         return True
 
+    def create_clp_guide_dataset(self, clp_label: str, dataset_name: str) -> xr.Dataset:
+        """Create dataset for clp guidance.
+
+        Parameters
+        ----------
+        clp_label : str
+            Label of the clp to guide.
+        dataset_name : str
+            Name of dataset to extract the guide from.
+
+        Returns
+        -------
+        xr.Dataset
+            Dataset containing the clp guide, with the ``clp_label`` dimension replaced by
+            the model dimension's first value.
+
+        Raises
+        ------
+        ValueError
+            If ``dataset_name`` is not in result.
+        ValueError
+            If ``clp_label`` is not in result.
+
+
+        Examples
+        --------
+        Extracting the clp guide from an optimization result object.
+
+        .. code-block:: python
+
+            from glotaran.io import save_dataset
+
+            clp_guide = result.create_clp_guide_dataset("species_1", "dataset_1")
+            save_dataset(clp_guide, "clp_guide__result_dataset_1__species_1.nc")
+
+
+        ..  # noqa: DAR402
+        """
+        return create_clp_guide_dataset(self, clp_label=clp_label, dataset_name=dataset_name)
+
     @deprecate(
         deprecated_qual_name_usage="glotaran.project.result.Result.get_dataset(dataset_label)",
         new_qual_name_usage=("glotaran.project.result.Result.data[dataset_label]"),
@@ -340,5 +381,5 @@ def get_dataset(self, dataset_label: str) -> xr.Dataset:
         """
         try:
             return self.data[dataset_label]
-        except KeyError:
-            raise ValueError(f"Unknown dataset '{dataset_label}'")
+        except KeyError as e:
+            raise ValueError(f"Unknown dataset '{dataset_label}'") from e
diff --git a/glotaran/project/test/test_result.py b/glotaran/project/test/test_result.py
index a4e7eb0fb..3172feff1 100644
--- a/glotaran/project/test/test_result.py
+++ b/glotaran/project/test/test_result.py
@@ -2,6 +2,7 @@
 
 from pathlib import Path
 
+import numpy as np
 import pytest
 import xarray as xr
 from IPython.core.formatters import format_display_data
@@ -12,6 +13,7 @@
 from glotaran.io import SavingOptions
 from glotaran.project.result import Result
 from glotaran.testing.simulated_data.sequential_spectral_decay import SCHEME
+from glotaran.testing.simulated_data.shared_decay import SPECTRAL_AXIS
 
 
 @pytest.fixture(scope="session")
@@ -45,6 +47,14 @@ def test_get_scheme(dummy_result: Result):
     )
 
 
+def test_result_create_clp_guide_dataset(dummy_result: Result):
+    """Check that clp guide has correct dimensions and dimension values."""
+    clp_guide = dummy_result.create_clp_guide_dataset("species_1", "dataset_1")
+    assert clp_guide.data.shape == (1, dummy_result.data["dataset_1"].spectral.size)
+    assert np.allclose(clp_guide.coords["time"].item(), -1)
+    assert np.allclose(clp_guide.coords["spectral"].values, SPECTRAL_AXIS)
+
+
 @pytest.mark.parametrize("saving_options", [SAVING_OPTIONS_MINIMAL, SAVING_OPTIONS_DEFAULT])
 def test_save_result(tmp_path: Path, saving_options: SavingOptions, dummy_result: Result):
     result_path = tmp_path / "test_result"
diff --git a/glotaran/utils/io.py b/glotaran/utils/io.py
index e8c6b0793..9d1d12845 100644
--- a/glotaran/utils/io.py
+++ b/glotaran/utils/io.py
@@ -21,6 +21,7 @@
 
     import pandas as pd
 
+    from glotaran.project.result import Result
     from glotaran.typing.types import StrOrPath
 
 
@@ -294,3 +295,90 @@ def make_path_absolute_if_relative(path: Path) -> Path:
     if not path.is_absolute():
         path = get_script_dir(nesting=2) / path
     return path
+
+
+def create_clp_guide_dataset(
+    result: Result | xr.Dataset, clp_label: str, dataset_name: str | None = None
+) -> xr.Dataset:
+    """Create dataset for clp guidance.
+
+    Parameters
+    ----------
+    result : Result | xr.Dataset
+        Optimization result object or dataset, created with pyglotaran>=0.6.0.
+    clp_label : str
+        Label of the clp to guide.
+    dataset_name : str | None
+        Name of dataset to extract the guide from. Defaults to None.
+
+    Returns
+    -------
+    xr.Dataset
+        Dataset containing the clp guide, with the ``clp_label`` dimension replaced by
+        the model dimension's first value.
+
+    Raises
+    ------
+    ValueError
+        If result is an instance of ``Result`` and ``dataset_name`` is ``None`` or not in result.
+    ValueError
+        If ``clp_label`` is not in result.
+    ValueError
+        If the result dataset was created with pyglotaran<0.6.0.
+
+    Examples
+    --------
+    Extracting the clp guide from an optimization result object.
+
+    .. code-block:: python
+
+        from glotaran.io import save_dataset
+        from glotaran.utils.io import create_clp_guide_dataset
+
+        clp_guide = create_clp_guide_dataset(result, "species_1", "dataset_1")
+        save_dataset(clp_guide, "clp_guide__result_dataset_1__species_1.nc")
+
+    Extracting the clp guide from a result dataset loaded from file.
+
+    .. code-block:: python
+
+        from glotaran.io import load_dataset
+        from glotaran.io import save_dataset
+        from glotaran.utils.io import create_clp_guide_dataset
+
+        result_dataset = load_dataset("result_dataset_1.nc")
+        clp_guide = create_clp_guide_dataset(result_dataset, "species_1")
+        save_dataset(clp_guide, "clp_guide__result_dataset_1__species_1.nc")
+
+    """
+    if isinstance(result, xr.Dataset):
+        dataset = result
+    elif dataset_name is None or dataset_name not in result.data:
+        raise ValueError(
+            f"Unknown dataset {dataset_name!r}. "
+            f"Known datasets are:\n {list(result.data.keys())}"
+        )
+    else:
+        dataset = result.data[dataset_name]
+    if clp_label not in dataset.clp_label:
+        raise ValueError(
+            f"Unknown clp_label {clp_label!r}. "
+            f"Known clp_labels are:\n {list(dataset.clp_label.values)}"
+        )
+    if "model_dimension" not in dataset.attrs:
+        raise ValueError(
+            "Result dataset is missing attribute 'model_dimension', "
+            "which means that it was created with pyglotaran<0.6.0. "
+            "Please recreate the result with the latest version of pyglotaran."
+        )
+
+    clp_values = dataset.clp.sel(clp_label=[clp_label])
+    value_dimension = next(filter(lambda x: x != dataset.model_dimension, clp_values.dims))
+
+    return xr.DataArray(
+        clp_values.values.T,
+        coords={
+            dataset.model_dimension: [dataset.coords[dataset.model_dimension][0].item()],
+            value_dimension: clp_values.coords[value_dimension].values,
+        },
+    ).to_dataset(name="data")
diff --git a/glotaran/utils/test/test_io.py b/glotaran/utils/test/test_io.py
index a25ba12a8..dc73a3c37 100644
--- a/glotaran/utils/test/test_io.py
+++ b/glotaran/utils/test/test_io.py
@@ -12,8 +12,14 @@
 from IPython.core.formatters import format_display_data
 from pandas.testing import assert_frame_equal
 
+from glotaran.analysis.optimize import optimize
+from glotaran.io import load_dataset
 from glotaran.io import save_dataset
+from glotaran.project.result import Result
+from glotaran.testing.simulated_data.sequential_spectral_decay import SCHEME
+from glotaran.testing.simulated_data.shared_decay import SPECTRAL_AXIS
 from glotaran.utils.io import DatasetMapping
+from glotaran.utils.io import create_clp_guide_dataset
 from glotaran.utils.io import load_datasets
 from glotaran.utils.io import relative_posix_path
 from glotaran.utils.io import safe_dataframe_fillna
@@ -40,6 +46,13 @@ def dummy_datasets(tmp_path: Path) -> tuple[Path, xr.Dataset, xr.Dataset]:
     return tmp_path, ds1, ds2
 
 
+@pytest.fixture(scope="session")
+def dummy_result():
+    """Dummy result for testing."""
+    print(SCHEME.data["dataset_1"])
+    yield optimize(SCHEME, raise_exception=True)
+
+
 def test_dataset_mapping(ds_mapping: DatasetMapping):
     """Basic mapping functionality of ``DatasetMapping``."""
 
@@ -255,3 +268,72 @@ def test_safe_dataframe_replace():
     safe_dataframe_replace(df, "not_a_column", np.inf, 2)
 
     assert_frame_equal(df, df2)
+
+
+def test_create_clp_guide_dataset(dummy_result: Result):
+    """Check that clp guide has correct dimensions and dimension values."""
+    clp_guide = create_clp_guide_dataset(dummy_result, "species_1", "dataset_1")
+
+    assert clp_guide.data.shape == (1, dummy_result.data["dataset_1"].spectral.size)
+    assert np.allclose(clp_guide.coords["time"].item(), -1)
+    assert np.allclose(clp_guide.coords["spectral"].values, SPECTRAL_AXIS)
+
+    clp_guide = create_clp_guide_dataset(dummy_result.data["dataset_1"], "species_1")
+
+    assert clp_guide.data.shape == (1, dummy_result.data["dataset_1"].spectral.size)
+    assert np.allclose(clp_guide.coords["time"].item(), -1)
+    assert np.allclose(clp_guide.coords["spectral"].values, SPECTRAL_AXIS)
+
+
+def test_create_clp_guide_dataset_errors(dummy_result: Result):
+    """Errors thrown when dataset or clp_label are not in result."""
+    with pytest.raises(ValueError) as exc_info:
+        create_clp_guide_dataset(dummy_result, "species_1", "not-a-dataset")
+
+    assert (
+        str(exc_info.value)
+        == "Unknown dataset 'not-a-dataset'. Known datasets are:\n ['dataset_1']"
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        create_clp_guide_dataset(dummy_result, "not-a-species", "dataset_1")
+
+    assert (
+        str(exc_info.value) == "Unknown clp_label 'not-a-species'. Known clp_labels are:\n "
+        "['species_1', 'species_2', 'species_3']"
+    )
+
+    dummy_dataset = dummy_result.data["dataset_1"].copy()
+    del dummy_dataset.attrs["model_dimension"]
+
+    with pytest.raises(ValueError) as exc_info:
+        create_clp_guide_dataset(dummy_dataset, "species_1")
+
+    assert (
+        str(exc_info.value) == "Result dataset is missing attribute 'model_dimension', "
+        "which means that it was created with pyglotaran<0.6.0. "
+        "Please recreate the result with the latest version of pyglotaran."
+    )
+
+
+def test_extract_sas_ascii_round_trip(dummy_result: Result, tmp_path: Path):
+    """Save to and then load from ascii results in the same data (spectrum)."""
+    tmp_file = tmp_path / "sas.ascii"
+
+    sas = create_clp_guide_dataset(dummy_result, "species_1", "dataset_1")
+    with pytest.warns(UserWarning) as rec_warn:
+        save_dataset(sas, tmp_file)
+
+    assert len(rec_warn) == 1
+    assert Path(rec_warn[0].filename).samefile(__file__)
+    assert rec_warn[0].message.args[0] == (
+        "Saving the 'data' attribute of 'dataset' as a fallback. "
+        "Result saving for ascii format only supports xarray.DataArray format, "
+        "please pass a xarray.DataArray instead of a xarray.Dataset (e.g. dataset.data)."
+    )
+
+    loaded_sas = load_dataset(tmp_file, prepare=False)
+
+    for dim in sas.dims:
+        assert all(sas.coords[dim] == loaded_sas.coords[dim]), f"Coordinate {dim} mismatch"
+    assert np.allclose(sas.data.values, loaded_sas.data.values)
diff --git a/pyproject.toml b/pyproject.toml
index f04e190af..031f79a3c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,7 +33,7 @@ exclude = [
   "benchmark/*"
 ]
 ignore-init-module = true
-fail-under = 59
+fail-under = 63
 
 [tool.nbqa.addopts]
 flake8 = [