From 671c6bdd0f6c052e6e11d76374b75216e8f719aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn=20Wei=C3=9Fenborn?= Date: Sat, 16 Oct 2021 16:12:05 +0200 Subject: [PATCH] Changed plugininterface to use SavingOptions --- .../io/ascii/wavelength_time_explicit_file.py | 1 + glotaran/builtin/io/folder/folder_plugin.py | 166 ++++++++++++------ .../io/folder/test/test_folder_plugin.py | 7 +- .../builtin/io/yml/test/test_save_result.py | 7 +- glotaran/builtin/io/yml/yml.py | 56 +++++- .../modules/test/test_project_result.py | 4 +- glotaran/io/__init__.py | 6 +- glotaran/io/interface.py | 31 +++- .../plugin_system/data_io_registration.py | 1 + .../plugin_system/project_io_registration.py | 21 +-- .../test/test_data_io_registration.py | 1 + .../test/test_project_io_registration.py | 4 + glotaran/project/result.py | 14 +- glotaran/project/test/test_result.py | 59 +++++++ 14 files changed, 292 insertions(+), 86 deletions(-) diff --git a/glotaran/builtin/io/ascii/wavelength_time_explicit_file.py b/glotaran/builtin/io/ascii/wavelength_time_explicit_file.py index 9bceecde5..8363190b7 100644 --- a/glotaran/builtin/io/ascii/wavelength_time_explicit_file.py +++ b/glotaran/builtin/io/ascii/wavelength_time_explicit_file.py @@ -279,6 +279,7 @@ def save_dataset( dataset: xr.DataArray, file_name: str, *, + data_filter: list[str] = None, comment: str = "", file_format: DataFileType = DataFileType.time_explicit, number_format: str = "%.10e", diff --git a/glotaran/builtin/io/folder/folder_plugin.py b/glotaran/builtin/io/folder/folder_plugin.py index 58171ef98..7ea461935 100644 --- a/glotaran/builtin/io/folder/folder_plugin.py +++ b/glotaran/builtin/io/folder/folder_plugin.py @@ -9,15 +9,18 @@ from pathlib import Path from typing import TYPE_CHECKING +from glotaran.io import SAVING_OPTIONS_DEFAULT from glotaran.io import save_dataset from glotaran.io import save_model from glotaran.io import save_parameters +from glotaran.io import save_result from glotaran.io import save_scheme from glotaran.io.interface import ProjectIoInterface -from glotaran.plugin_system.project_io_registration import SAVING_OPTIONS_DEFAULT from glotaran.plugin_system.project_io_registration import register_project_io if TYPE_CHECKING: + from os import PathLike + from glotaran.plugin_system.project_io_registration import SavingOptions from glotaran.project import Result @@ -33,8 +36,10 @@ class FolderProjectIo(ProjectIoInterface): def save_result( self, result: Result, - result_path: str, + result_path: str | PathLike[str], + format_name: str = None, *, + allow_overwrite: bool = False, saving_options: SavingOptions = SAVING_OPTIONS_DEFAULT, ) -> list[str]: """Save the result to a given folder. @@ -57,11 +62,16 @@ def save_result( Parameters ---------- result : Result - Result instance to be saved. - result_path : str - The path to the folder in which to save the result. + :class:`Result` instance to write. + result_path : str | PathLike[str] + Path to write the result data to. + format_name : str + Format the result should be saved in, if not provided and it is a file + it will be inferred from the file extension. + allow_overwrite : bool + Whether or not to allow overwriting existing files, by default False saving_options : SavingOptions - Options for saving the the result. + Options for the saved result. Returns @@ -77,52 +87,106 @@ def save_result( result_folder = Path(result_path) if result_folder.is_file(): raise ValueError(f"The path '{result_folder}' is not a directory.") - result_folder.mkdir(parents=True, exist_ok=True) - - paths = [] - if saving_options.report: - report_path = result_folder / "result.md" - report_path.write_text(str(result.markdown())) - paths.append(report_path.as_posix()) - - model_path = result_folder / "model.yml" - save_model(result.scheme.model, model_path, allow_overwrite=True) - paths.append(model_path.as_posix()) - - initial_parameters_path = f"initial_parameters.{saving_options.parameter_format}" - save_parameters( - result.scheme.parameters, - result_folder / initial_parameters_path, - format_name=saving_options.parameter_format, - allow_overwrite=True, + + return save_result( + result, + result_folder / "glotaran_result.yml", + allow_overwrite=allow_overwrite, + saving_options=saving_options, ) - paths.append((result_folder / initial_parameters_path).as_posix()) - optimized_parameters_path = f"optimized_parameters.{saving_options.parameter_format}" - save_parameters( - result.optimized_parameters, - result_folder / optimized_parameters_path, - format_name=saving_options.parameter_format, + +def save_result_to_folder( + result: Result, + result_path: str | PathLike[str], + saving_options: SavingOptions, +) -> list[str]: + """Save the result to a given folder. + + Returns a list with paths of all saved items. + The following files are saved if not configured otherwise: + * `result.md`: The result with the model formatted as markdown text. + * `model.yml`: Model spec file. + * `scheme.yml`: Scheme spec file. + * `initial_parameters.csv`: Initially used parameters. + * `optimized_parameters.csv`: The optimized parameter as csv file. + * `parameter_history.csv`: Parameter changes over the optimization + * `{dataset_label}.nc`: The result data for each dataset as NetCDF file. + + Note + ---- + As a side effect it populates the file path properties of ``result`` which can be + used in other plugins (e.g. the ``yml`` save_result). + + Parameters + ---------- + result : Result + :class:`Result` instance to write. + result_path : str | PathLike[str] + Path to write the result data to. + saving_options : SavingOptions + Options for the saved result. + + Returns + ------- + list[str] + List of file paths which were created. + + Raises + ------ + ValueError + If ``result_path`` is a file. + """ + result_folder = Path(result_path) + if result_folder.is_file(): + raise ValueError(f"The path '{result_folder}' is not a directory.") + result_folder.mkdir(parents=True, exist_ok=True) + + paths = [] + if saving_options.report: + report_path = result_folder / "result.md" + report_path.write_text(str(result.markdown())) + paths.append(report_path.as_posix()) + + model_path = result_folder / "model.yml" + save_model(result.scheme.model, model_path, allow_overwrite=True) + paths.append(model_path.as_posix()) + + initial_parameters_path = f"initial_parameters.{saving_options.parameter_format}" + save_parameters( + result.scheme.parameters, + result_folder / initial_parameters_path, + format_name=saving_options.parameter_format, + allow_overwrite=True, + ) + paths.append((result_folder / initial_parameters_path).as_posix()) + + optimized_parameters_path = f"optimized_parameters.{saving_options.parameter_format}" + save_parameters( + result.optimized_parameters, + result_folder / optimized_parameters_path, + format_name=saving_options.parameter_format, + allow_overwrite=True, + ) + paths.append((result_folder / optimized_parameters_path).as_posix()) + + scheme_path = result_folder / "scheme.yml" + save_scheme(result.scheme, scheme_path, allow_overwrite=True) + paths.append(scheme_path.as_posix()) + + parameter_history_path = result_folder / "parameter_history.csv" + result.parameter_history.to_csv(parameter_history_path) + paths.append(parameter_history_path.as_posix()) + + for label, dataset in result.data.items(): + data_path = result_folder / f"{label}.{saving_options.data_format}" + save_dataset( + dataset, + data_path, + format_name=saving_options.data_format, allow_overwrite=True, + data_filters=saving_options.data_filter, ) - paths.append((result_folder / optimized_parameters_path).as_posix()) - - scheme_path = result_folder / "scheme.yml" - save_scheme(result.scheme, scheme_path, allow_overwrite=True) - paths.append(scheme_path.as_posix()) - - parameter_history_path = result_folder / "parameter_history.csv" - result.parameter_history.to_csv(parameter_history_path) - paths.append(parameter_history_path.as_posix()) - - for label, dataset in result.data.items(): - data_path = result_folder / f"{label}.{saving_options.data_format}" - save_dataset( - dataset, - data_path, - format_name=saving_options.data_format, - allow_overwrite=True, - ) - paths.append(data_path.as_posix()) - - return paths + paths.append(data_path.as_posix()) + + return paths diff --git a/glotaran/builtin/io/folder/test/test_folder_plugin.py b/glotaran/builtin/io/folder/test/test_folder_plugin.py index 8d200a7b9..4efe4b6ef 100644 --- a/glotaran/builtin/io/folder/test/test_folder_plugin.py +++ b/glotaran/builtin/io/folder/test/test_folder_plugin.py @@ -18,7 +18,7 @@ def dummy_result(): yield optimize(SCHEME, raise_exception=True) -@pytest.mark.parametrize("format_name", ("folder", "legacy")) +@pytest.mark.parametrize("format_name", ["folder", "legacy"]) def test_save_result_folder( tmp_path: Path, dummy_result: Result, @@ -27,6 +27,7 @@ def test_save_result_folder( """Check all files exist.""" result_dir = tmp_path / "testresult" + assert not result_dir.exists() save_paths = save_result( result_path=str(result_dir), format_name=format_name, result=dummy_result ) @@ -45,7 +46,7 @@ def test_save_result_folder( assert (result_dir / wanted).as_posix() in save_paths -@pytest.mark.parametrize("format_name", ("folder", "legacy")) +@pytest.mark.parametrize("format_name", ["folder", "legacy"]) def test_save_result_folder_error_path_is_file( tmp_path: Path, dummy_result: Result, @@ -53,7 +54,7 @@ def test_save_result_folder_error_path_is_file( ): """Raise error if result_path is a file without extension and overwrite is true.""" - result_dir = tmp_path / "testresult" + result_dir = tmp_path / "testresulterror" result_dir.touch() with pytest.raises(ValueError, match="The path '.+?' is not a directory."): diff --git a/glotaran/builtin/io/yml/test/test_save_result.py b/glotaran/builtin/io/yml/test/test_save_result.py index 87f8669a9..f557c0326 100644 --- a/glotaran/builtin/io/yml/test/test_save_result.py +++ b/glotaran/builtin/io/yml/test/test_save_result.py @@ -16,8 +16,8 @@ @pytest.fixture(scope="session") def dummy_result(): """Dummy result for testing.""" + print(SCHEME.data["dataset_1"]) scheme = replace(SCHEME, maximum_number_function_evaluations=1) - print(scheme.data["dataset_1"]) yield optimize(scheme, raise_exception=True) @@ -58,4 +58,7 @@ def test_save_result_yml( assert (result_dir / "dataset_1.nc").exists() # We can't check equality due to numerical fluctuations - assert expected in (result_dir / "result.yml").read_text() + got = (result_dir / "result.yml").read_text() + print(got) + assert expected in got + save_result(result_path=tmp_path / "result.yml", result=dummy_result) diff --git a/glotaran/builtin/io/yml/yml.py b/glotaran/builtin/io/yml/yml.py index 167f98309..d5872207a 100644 --- a/glotaran/builtin/io/yml/yml.py +++ b/glotaran/builtin/io/yml/yml.py @@ -5,11 +5,12 @@ from ruamel.yaml import YAML +from glotaran.builtin.io.folder.folder_plugin import save_result_to_folder from glotaran.deprecation.modules.builtin_io_yml import model_spec_deprecations from glotaran.deprecation.modules.builtin_io_yml import scheme_spec_deprecations +from glotaran.io import SAVING_OPTIONS_DEFAULT from glotaran.io import ProjectIoInterface from glotaran.io import register_project_io -from glotaran.io import save_result from glotaran.model import Model from glotaran.parameter import ParameterGroup from glotaran.project import Result @@ -19,12 +20,15 @@ from glotaran.utils.sanitize import sanitize_yaml if TYPE_CHECKING: + from os import PathLike from typing import Any from typing import Mapping from ruamel.yaml.nodes import ScalarNode from ruamel.yaml.representer import BaseRepresenter + from glotaran.plugin_system.project_io_registration import SavingOptions + @register_project_io(["yml", "yaml", "yml_str"]) class YmlProjectIo(ProjectIoInterface): @@ -126,11 +130,35 @@ def load_result(self, result_path: str) -> Result: Result :class:`Result` instance created from the saved format. """ + spec = self._load_yml(result_path) return fromdict(Result, spec, folder=Path(result_path).parent) - def save_result(self, result: Result, result_path: str): - """Write a :class:`Result` instance to a spec file. + def save_result( + self, + result: Result, + result_path: str | PathLike[str], + format_name: str = None, + *, + allow_overwrite: bool = False, + saving_options: SavingOptions = SAVING_OPTIONS_DEFAULT, + ) -> list[str]: + """Save the result to a given folder. + + Returns a list with paths of all saved items. + The following files are saved if not configured otherwise: + * `result.md`: The result with the model formatted as markdown text. + * `model.yml`: Model spec file. + * `scheme.yml`: Scheme spec file. + * `initial_parameters.csv`: Initially used parameters. + * `optimized_parameters.csv`: The optimized parameter as csv file. + * `parameter_history.csv`: Parameter changes over the optimization + * `{dataset_label}.nc`: The result data for each dataset as NetCDF file. + + Note + ---- + As a side effect it populates the file path properties of ``result`` which can be + used in other plugins (e.g. the ``yml`` save_result). Parameters ---------- @@ -138,10 +166,30 @@ def save_result(self, result: Result, result_path: str): :class:`Result` instance to write. result_path : str | PathLike[str] Path to write the result data to. + format_name : str + Format the result should be saved in, if not provided and it is a file + it will be inferred from the file extension. + allow_overwrite : bool + Whether or not to allow overwriting existing files, by default False + saving_options : SavingOptions + Options for the saved result. + + + Returns + ------- + list[str] + List of file paths which were created. + + Raises + ------ + ValueError + If ``result_path`` is a file. """ - save_result(result, Path(result_path).parent.as_posix(), format_name="folder") + paths = save_result_to_folder(result, Path(result_path).parent.as_posix(), saving_options) result_dict = asdict(result, folder=Path(result_path).parent) _write_dict(result_path, result_dict) + paths.append(result_path) + return paths def _load_yml(self, file_name: str) -> dict[str, Any]: yaml = YAML() diff --git a/glotaran/deprecation/modules/test/test_project_result.py b/glotaran/deprecation/modules/test/test_project_result.py index 3e72888bb..697033b35 100644 --- a/glotaran/deprecation/modules/test/test_project_result.py +++ b/glotaran/deprecation/modules/test/test_project_result.py @@ -16,7 +16,7 @@ def dummy_result(): yield optimize(SCHEME, raise_exception=True) -def test_Result_get_dataset_method(dummy_result: Result): +def test_result_get_dataset_method(dummy_result: Result): """Result.get_dataset(dataset_label) gives correct dataset.""" _, result = deprecation_warning_on_call_test_helper( @@ -26,7 +26,7 @@ def test_Result_get_dataset_method(dummy_result: Result): assert result == dummy_result.data["dataset_1"] -def test_Result_get_dataset_method_error(dummy_result: Result): +def test_result_get_dataset_method_error(dummy_result: Result): """Result.get_dataset(dataset_label) error on wrong key.""" with pytest.raises(ValueError, match="Unknown dataset 'foo'"): diff --git a/glotaran/io/__init__.py b/glotaran/io/__init__.py index 343456a92..ef3e254ed 100644 --- a/glotaran/io/__init__.py +++ b/glotaran/io/__init__.py @@ -6,8 +6,11 @@ reexports functions from the pluginsystem from a common place. """ +from glotaran.io.interface import SAVING_OPTIONS_DEFAULT +from glotaran.io.interface import SAVING_OPTIONS_MINIMAL from glotaran.io.interface import DataIoInterface from glotaran.io.interface import ProjectIoInterface +from glotaran.io.interface import SavingOptions from glotaran.io.prepare_dataset import prepare_time_trace_dataset from glotaran.plugin_system.data_io_registration import data_io_plugin_table from glotaran.plugin_system.data_io_registration import get_dataloader @@ -17,9 +20,6 @@ from glotaran.plugin_system.data_io_registration import save_dataset from glotaran.plugin_system.data_io_registration import set_data_plugin from glotaran.plugin_system.data_io_registration import show_data_io_method_help -from glotaran.plugin_system.project_io_registration import SAVING_OPTIONS_DEFAULT -from glotaran.plugin_system.project_io_registration import SAVING_OPTIONS_MINIMAL -from glotaran.plugin_system.project_io_registration import SavingOptions from glotaran.plugin_system.project_io_registration import get_project_io_method from glotaran.plugin_system.project_io_registration import load_model from glotaran.plugin_system.project_io_registration import load_parameters diff --git a/glotaran/io/interface.py b/glotaran/io/interface.py index f6c5e5019..aacde089b 100644 --- a/glotaran/io/interface.py +++ b/glotaran/io/interface.py @@ -12,10 +12,12 @@ from __future__ import annotations +from dataclasses import dataclass from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Callable + from typing import Literal from typing import Union import xarray as xr @@ -29,6 +31,20 @@ DataSaver = Callable[[str, Union[xr.Dataset, xr.DataArray]], None] +@dataclass +class SavingOptions: + """A collection of options for result saving.""" + + data_filter: list[str] | None = None + data_format: Literal["nc"] = "nc" + parameter_format: Literal["csv"] = "csv" + report: bool = True + + +SAVING_OPTIONS_DEFAULT = SavingOptions() +SAVING_OPTIONS_MINIMAL = SavingOptions(data_filter=["fitted_data", "residual"], report=False) + + class DataIoInterface: """Baseclass for Data IO plugins.""" @@ -62,9 +78,7 @@ def load_dataset(self, file_name: str) -> xr.Dataset | xr.DataArray: raise NotImplementedError(f"""Cannot read data with format: {self.format!r}""") def save_dataset( - self, - dataset: xr.Dataset | xr.DataArray, - file_name: str, + self, dataset: xr.Dataset | xr.DataArray, file_name: str, data_filters: list[str] = None ): """Save data from :xarraydoc:`Dataset` to a file (**NOT IMPLEMENTED**). @@ -74,6 +88,8 @@ def save_dataset( Dataset to be saved to file. file_name : str File to write the data to. + data_filters : list[str] + A list of dataset items items to save. .. # noqa: DAR101 @@ -218,7 +234,12 @@ def load_result(self, result_path: str) -> Result: """ raise NotImplementedError(f"Cannot read result with format {self.format!r}") - def save_result(self, result: Result, result_path: str) -> list[str] | None: + def save_result( + self, + result: Result, + result_path: str, + saving_options: SavingOptions = SAVING_OPTIONS_DEFAULT, + ) -> list[str] | None: """Save a Result instance to a spec file (**NOT IMPLEMENTED**). Parameters @@ -227,6 +248,8 @@ def save_result(self, result: Result, result_path: str) -> list[str] | None: Result instance to save to specs file. result_path : str Path to write the result data to. + saving_options : SavingOptions + Options for the saved result. .. # noqa: DAR101 diff --git a/glotaran/plugin_system/data_io_registration.py b/glotaran/plugin_system/data_io_registration.py index 549ac20dc..12215ef57 100644 --- a/glotaran/plugin_system/data_io_registration.py +++ b/glotaran/plugin_system/data_io_registration.py @@ -242,6 +242,7 @@ def save_dataset( io.save_dataset( # type: ignore[call-arg] file_name=Path(file_name).as_posix(), dataset=dataset, + data_filters=data_filters, **kwargs, ) dataset.attrs["loader"] = load_dataset diff --git a/glotaran/plugin_system/project_io_registration.py b/glotaran/plugin_system/project_io_registration.py index 6ca838b35..d3b9d0b44 100644 --- a/glotaran/plugin_system/project_io_registration.py +++ b/glotaran/plugin_system/project_io_registration.py @@ -8,14 +8,15 @@ """ from __future__ import annotations -from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING from typing import TypeVar from tabulate import tabulate +from glotaran.io.interface import SAVING_OPTIONS_DEFAULT from glotaran.io.interface import ProjectIoInterface +from glotaran.io.interface import SavingOptions from glotaran.plugin_system.base_registry import __PluginRegistry from glotaran.plugin_system.base_registry import add_instantiated_plugin_to_registry from glotaran.plugin_system.base_registry import get_method_from_plugin @@ -54,20 +55,6 @@ Literal["save_result"], ) - -@dataclass -class SavingOptions: - """A collection of options for result saving.""" - - data_filter: list[str] | None = None - data_format: Literal["nc"] = "nc" - parameter_format: Literal["csv"] = "csv" - report: bool = True - - -SAVING_OPTIONS_DEFAULT = SavingOptions() -SAVING_OPTIONS_MINIMAL = SavingOptions(data_filter=["fitted_data", "residual"], report=False) - PROJECT_IO_METHODS = ( "load_model", "save_model", @@ -447,6 +434,7 @@ def save_result( *, allow_overwrite: bool = False, update_source_path: bool = True, + saving_options: SavingOptions = SAVING_OPTIONS_DEFAULT, **kwargs: Any, ) -> list[str] | None: """Write a :class:`Result` instance to a spec file. @@ -465,6 +453,8 @@ def save_result( update_source_path: bool Whether or not to update the ``source_path`` attribute to ``result_path`` when saving. by default True + saving_options : SavingOptions + Options for the saved result. **kwargs : Any Additional keyword arguments passes to the ``save_result`` implementation of the project io plugin. @@ -481,6 +471,7 @@ def save_result( paths = io.save_result( # type: ignore[call-arg] result_path=Path(result_path).as_posix(), result=result, + saving_options=saving_options, **kwargs, ) if update_source_path is True: diff --git a/glotaran/plugin_system/test/test_data_io_registration.py b/glotaran/plugin_system/test/test_data_io_registration.py index 90efd1118..dcbd43b56 100644 --- a/glotaran/plugin_system/test/test_data_io_registration.py +++ b/glotaran/plugin_system/test/test_data_io_registration.py @@ -48,6 +48,7 @@ def save_dataset( # type:ignore[override] self, file_name: StrOrPath, dataset: xr.Dataset | xr.DataArray, + data_filters: list[str] = None, *, result_container: dict[str, Any], **kwargs: Any, diff --git a/glotaran/plugin_system/test/test_project_io_registration.py b/glotaran/plugin_system/test/test_project_io_registration.py index ae0c1c0f8..f9b69d246 100644 --- a/glotaran/plugin_system/test/test_project_io_registration.py +++ b/glotaran/plugin_system/test/test_project_io_registration.py @@ -12,6 +12,8 @@ from glotaran.parameter import ParameterGroup from glotaran.plugin_system.base_registry import PluginOverwriteWarning from glotaran.plugin_system.base_registry import __PluginRegistry +from glotaran.plugin_system.project_io_registration import SAVING_OPTIONS_DEFAULT +from glotaran.plugin_system.project_io_registration import SavingOptions from glotaran.plugin_system.project_io_registration import get_project_io from glotaran.plugin_system.project_io_registration import get_project_io_method from glotaran.plugin_system.project_io_registration import is_known_project_format @@ -116,6 +118,8 @@ def save_result( # type:ignore[override] self, result: Result, result_path: StrOrPath, + *, + saving_options: SavingOptions = SAVING_OPTIONS_DEFAULT, **kwargs: Any, ): result.func_args.update( # type:ignore[attr-defined] diff --git a/glotaran/project/result.py b/glotaran/project/result.py index 9fb6bd9bf..586a56a28 100644 --- a/glotaran/project/result.py +++ b/glotaran/project/result.py @@ -15,6 +15,8 @@ from tabulate import tabulate from glotaran.deprecation import deprecate +from glotaran.io import SAVING_OPTIONS_DEFAULT +from glotaran.io import SavingOptions from glotaran.io import load_result from glotaran.io import save_result from glotaran.model import Model @@ -246,13 +248,15 @@ def __str__(self) -> str: """Overwrite of ``__str__``.""" return str(self.markdown(with_model=False)) - def save(self, path: str) -> list[str]: + def save(self, path: str, saving_options: SavingOptions = SAVING_OPTIONS_DEFAULT) -> list[str]: """Save the result to given folder. Parameters ---------- path : str The path to the folder in which to save the result. + saving_options : SavingOptions + Options for the saved result. Returns ------- @@ -261,7 +265,13 @@ def save(self, path: str) -> list[str]: """ return cast( List[str], - save_result(result_path=path, result=self, format_name="folder", allow_overwrite=True), + save_result( + result_path=path, + result=self, + format_name="folder", + allow_overwrite=True, + saving_options=saving_options, + ), ) def recreate(self) -> Result: diff --git a/glotaran/project/test/test_result.py b/glotaran/project/test/test_result.py index 88aff2938..2a2e04e7e 100644 --- a/glotaran/project/test/test_result.py +++ b/glotaran/project/test/test_result.py @@ -1,9 +1,15 @@ from __future__ import annotations +from pathlib import Path + import pytest +import xarray as xr from IPython.core.formatters import format_display_data from glotaran.analysis.optimize import optimize +from glotaran.io import SAVING_OPTIONS_DEFAULT +from glotaran.io import SAVING_OPTIONS_MINIMAL +from glotaran.io import SavingOptions from glotaran.project.result import Result from glotaran.testing.sequential_spectral_decay import SCHEME @@ -27,3 +33,56 @@ def test_result_ipython_rendering(dummy_result: Result): assert "text/markdown" in rendered_markdown_return assert rendered_markdown_return["text/markdown"].startswith("| Optimization Result") + + +def test_get_scheme(dummy_result: Result): + scheme = dummy_result.get_scheme() + assert "residual" not in dummy_result.scheme.data["dataset_1"] + assert "residual" not in scheme.data["dataset_1"] + assert all(scheme.parameters.to_dataframe() != dummy_result.scheme.parameters.to_dataframe()) + assert all( + scheme.parameters.to_dataframe() == dummy_result.optimized_parameters.to_dataframe() + ) + + +@pytest.mark.parametrize("saving_options", [SAVING_OPTIONS_MINIMAL, SAVING_OPTIONS_DEFAULT]) +def test_save_result(tmp_path: Path, saving_options: SavingOptions, dummy_result: Result): + result_path = tmp_path / "test_result" + dummy_result.save(str(result_path), saving_options=saving_options) + files_must_exist = [ + "glotaran_result.yml", + "scheme.yml", + "model.yml", + "initial_parameters.csv", + "optimized_parameters.csv", + "parameter_history.csv", + "dataset_1.nc", + ] + files_must_not_exist = [] + if saving_options.report: + files_must_exist.append("result.md") + else: + files_must_not_exist.append("result.md") + + for file in files_must_exist: + assert (result_path / file).exists() + + for file in files_must_not_exist: + assert not (result_path / file).exists() + + dataset_path = result_path / "dataset_1.nc" + assert dataset_path.exists() + dataset = xr.open_dataset(dataset_path) + print(dataset) + if saving_options.data_filter is not None: + assert len(saving_options.data_filter) == len(dataset) + assert all(d in dataset for d in saving_options.data_filter) + + +def test_recreate(dummy_result): + recreated_result = dummy_result.recreate() + assert recreated_result.success + + +def test_verify(dummy_result): + assert dummy_result.verify()