Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Add SavingOptions to save_result API #874

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ def save_dataset(
dataset: xr.DataArray,
file_name: str,
*,
data_filter: list[str] = None,
comment: str = "",
file_format: DataFileType = DataFileType.time_explicit,
number_format: str = "%.10e",
Expand Down
166 changes: 115 additions & 51 deletions glotaran/builtin/io/folder/folder_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@
from pathlib import Path
from typing import TYPE_CHECKING

from glotaran.io import SAVING_OPTIONS_DEFAULT
from glotaran.io import save_dataset
from glotaran.io import save_model
from glotaran.io import save_parameters
from glotaran.io import save_result
from glotaran.io import save_scheme
from glotaran.io.interface import ProjectIoInterface
from glotaran.plugin_system.project_io_registration import SAVING_OPTIONS_DEFAULT
from glotaran.plugin_system.project_io_registration import register_project_io

if TYPE_CHECKING:
from os import PathLike

from glotaran.plugin_system.project_io_registration import SavingOptions
from glotaran.project import Result

Expand All @@ -33,8 +36,10 @@ class FolderProjectIo(ProjectIoInterface):
def save_result(
self,
result: Result,
result_path: str,
result_path: str | PathLike[str],
format_name: str = None,
*,
allow_overwrite: bool = False,
saving_options: SavingOptions = SAVING_OPTIONS_DEFAULT,
) -> list[str]:
"""Save the result to a given folder.
Expand All @@ -57,11 +62,16 @@ def save_result(
Parameters
----------
result : Result
Result instance to be saved.
result_path : str
The path to the folder in which to save the result.
:class:`Result` instance to write.
result_path : str | PathLike[str]
Path to write the result data to.
format_name : str
Format the result should be saved in, if not provided and it is a file
it will be inferred from the file extension.
allow_overwrite : bool
Whether or not to allow overwriting existing files, by default False
saving_options : SavingOptions
Options for saving the the result.
Options for the saved result.
Returns
Expand All @@ -77,52 +87,106 @@ def save_result(
result_folder = Path(result_path)
if result_folder.is_file():
raise ValueError(f"The path '{result_folder}' is not a directory.")
result_folder.mkdir(parents=True, exist_ok=True)

paths = []
if saving_options.report:
report_path = result_folder / "result.md"
report_path.write_text(str(result.markdown()))
paths.append(report_path.as_posix())

model_path = result_folder / "model.yml"
save_model(result.scheme.model, model_path, allow_overwrite=True)
paths.append(model_path.as_posix())

initial_parameters_path = f"initial_parameters.{saving_options.parameter_format}"
save_parameters(
result.scheme.parameters,
result_folder / initial_parameters_path,
format_name=saving_options.parameter_format,
allow_overwrite=True,

return save_result(
result,
result_folder / "glotaran_result.yml",
allow_overwrite=allow_overwrite,
saving_options=saving_options,
)
paths.append((result_folder / initial_parameters_path).as_posix())

optimized_parameters_path = f"optimized_parameters.{saving_options.parameter_format}"
save_parameters(
result.optimized_parameters,
result_folder / optimized_parameters_path,
format_name=saving_options.parameter_format,

def save_result_to_folder(
result: Result,
result_path: str | PathLike[str],
saving_options: SavingOptions,
) -> list[str]:
"""Save the result to a given folder.
Returns a list with paths of all saved items.
The following files are saved if not configured otherwise:
* `result.md`: The result with the model formatted as markdown text.
* `model.yml`: Model spec file.
* `scheme.yml`: Scheme spec file.
* `initial_parameters.csv`: Initially used parameters.
* `optimized_parameters.csv`: The optimized parameter as csv file.
* `parameter_history.csv`: Parameter changes over the optimization
* `{dataset_label}.nc`: The result data for each dataset as NetCDF file.
Note
----
As a side effect it populates the file path properties of ``result`` which can be
used in other plugins (e.g. the ``yml`` save_result).
Parameters
----------
result : Result
:class:`Result` instance to write.
result_path : str | PathLike[str]
Path to write the result data to.
saving_options : SavingOptions
Options for the saved result.
Returns
-------
list[str]
List of file paths which were created.
Raises
------
ValueError
If ``result_path`` is a file.
"""
result_folder = Path(result_path)
if result_folder.is_file():
raise ValueError(f"The path '{result_folder}' is not a directory.")
result_folder.mkdir(parents=True, exist_ok=True)

paths = []
if saving_options.report:
report_path = result_folder / "result.md"
report_path.write_text(str(result.markdown()))
paths.append(report_path.as_posix())

model_path = result_folder / "model.yml"
save_model(result.scheme.model, model_path, allow_overwrite=True)
paths.append(model_path.as_posix())

initial_parameters_path = f"initial_parameters.{saving_options.parameter_format}"
save_parameters(
result.scheme.parameters,
result_folder / initial_parameters_path,
format_name=saving_options.parameter_format,
allow_overwrite=True,
)
paths.append((result_folder / initial_parameters_path).as_posix())

optimized_parameters_path = f"optimized_parameters.{saving_options.parameter_format}"
save_parameters(
result.optimized_parameters,
result_folder / optimized_parameters_path,
format_name=saving_options.parameter_format,
allow_overwrite=True,
)
paths.append((result_folder / optimized_parameters_path).as_posix())

scheme_path = result_folder / "scheme.yml"
save_scheme(result.scheme, scheme_path, allow_overwrite=True)
paths.append(scheme_path.as_posix())

parameter_history_path = result_folder / "parameter_history.csv"
result.parameter_history.to_csv(parameter_history_path)
paths.append(parameter_history_path.as_posix())

for label, dataset in result.data.items():
data_path = result_folder / f"{label}.{saving_options.data_format}"
save_dataset(
dataset,
data_path,
format_name=saving_options.data_format,
allow_overwrite=True,
data_filters=saving_options.data_filter,
)
paths.append((result_folder / optimized_parameters_path).as_posix())

scheme_path = result_folder / "scheme.yml"
save_scheme(result.scheme, scheme_path, allow_overwrite=True)
paths.append(scheme_path.as_posix())

parameter_history_path = result_folder / "parameter_history.csv"
result.parameter_history.to_csv(parameter_history_path)
paths.append(parameter_history_path.as_posix())

for label, dataset in result.data.items():
data_path = result_folder / f"{label}.{saving_options.data_format}"
save_dataset(
dataset,
data_path,
format_name=saving_options.data_format,
allow_overwrite=True,
)
paths.append(data_path.as_posix())

return paths
paths.append(data_path.as_posix())

return paths
7 changes: 4 additions & 3 deletions glotaran/builtin/io/folder/test/test_folder_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def dummy_result():
yield optimize(SCHEME, raise_exception=True)


@pytest.mark.parametrize("format_name", ("folder", "legacy"))
@pytest.mark.parametrize("format_name", ["folder", "legacy"])
def test_save_result_folder(
tmp_path: Path,
dummy_result: Result,
Expand All @@ -27,6 +27,7 @@ def test_save_result_folder(
"""Check all files exist."""

result_dir = tmp_path / "testresult"
assert not result_dir.exists()
save_paths = save_result(
result_path=str(result_dir), format_name=format_name, result=dummy_result
)
Expand All @@ -45,15 +46,15 @@ def test_save_result_folder(
assert (result_dir / wanted).as_posix() in save_paths


@pytest.mark.parametrize("format_name", ("folder", "legacy"))
@pytest.mark.parametrize("format_name", ["folder", "legacy"])
def test_save_result_folder_error_path_is_file(
tmp_path: Path,
dummy_result: Result,
format_name: Literal["folder", "legacy"],
):
"""Raise error if result_path is a file without extension and overwrite is true."""

result_dir = tmp_path / "testresult"
result_dir = tmp_path / "testresulterror"
result_dir.touch()

with pytest.raises(ValueError, match="The path '.+?' is not a directory."):
Expand Down
7 changes: 5 additions & 2 deletions glotaran/builtin/io/yml/test/test_save_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
@pytest.fixture(scope="session")
def dummy_result():
"""Dummy result for testing."""
print(SCHEME.data["dataset_1"])
scheme = replace(SCHEME, maximum_number_function_evaluations=1)
print(scheme.data["dataset_1"])
yield optimize(scheme, raise_exception=True)


Expand Down Expand Up @@ -58,4 +58,7 @@ def test_save_result_yml(
assert (result_dir / "dataset_1.nc").exists()

# We can't check equality due to numerical fluctuations
assert expected in (result_dir / "result.yml").read_text()
got = (result_dir / "result.yml").read_text()
print(got)
assert expected in got
save_result(result_path=tmp_path / "result.yml", result=dummy_result)
56 changes: 52 additions & 4 deletions glotaran/builtin/io/yml/yml.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from ruamel.yaml import YAML
from ruamel.yaml.compat import StringIO

from glotaran.builtin.io.folder.folder_plugin import save_result_to_folder
from glotaran.deprecation.modules.builtin_io_yml import model_spec_deprecations
from glotaran.deprecation.modules.builtin_io_yml import scheme_spec_deprecations
from glotaran.io import SAVING_OPTIONS_DEFAULT
from glotaran.io import ProjectIoInterface
from glotaran.io import register_project_io
from glotaran.io import save_result
from glotaran.model import Model
from glotaran.parameter import ParameterGroup
from glotaran.project import Result
Expand All @@ -20,12 +21,15 @@
from glotaran.utils.sanitize import sanitize_yaml

if TYPE_CHECKING:
from os import PathLike
from typing import Any
from typing import Mapping

from ruamel.yaml.nodes import ScalarNode
from ruamel.yaml.representer import BaseRepresenter

from glotaran.plugin_system.project_io_registration import SavingOptions


@register_project_io(["yml", "yaml", "yml_str"])
class YmlProjectIo(ProjectIoInterface):
Expand Down Expand Up @@ -127,22 +131,66 @@ def load_result(self, result_path: str) -> Result:
Result
:class:`Result` instance created from the saved format.
"""

spec = self._load_yml(result_path)
return fromdict(Result, spec, folder=Path(result_path).parent)

def save_result(self, result: Result, result_path: str):
"""Write a :class:`Result` instance to a spec file.
def save_result(
self,
result: Result,
result_path: str | PathLike[str],
format_name: str = None,
*,
allow_overwrite: bool = False,
saving_options: SavingOptions = SAVING_OPTIONS_DEFAULT,
) -> list[str]:
"""Save the result to a given folder.

Returns a list with paths of all saved items.
The following files are saved if not configured otherwise:
* `result.md`: The result with the model formatted as markdown text.
* `model.yml`: Model spec file.
* `scheme.yml`: Scheme spec file.
* `initial_parameters.csv`: Initially used parameters.
* `optimized_parameters.csv`: The optimized parameter as csv file.
* `parameter_history.csv`: Parameter changes over the optimization
* `{dataset_label}.nc`: The result data for each dataset as NetCDF file.

Note
----
As a side effect it populates the file path properties of ``result`` which can be
used in other plugins (e.g. the ``yml`` save_result).

Parameters
----------
result : Result
:class:`Result` instance to write.
result_path : str | PathLike[str]
Path to write the result data to.
format_name : str
Format the result should be saved in, if not provided and it is a file
it will be inferred from the file extension.
allow_overwrite : bool
Whether or not to allow overwriting existing files, by default False
saving_options : SavingOptions
Options for the saved result.


Returns
-------
list[str]
List of file paths which were created.

Raises
------
ValueError
If ``result_path`` is a file.
"""
save_result(result, Path(result_path).parent.as_posix(), format_name="folder")
paths = save_result_to_folder(result, Path(result_path).parent.as_posix(), saving_options)
result_dict = asdict(result, folder=Path(result_path).parent)
write_dict(result_dict, file_name=result_path)
paths.append(result_path)
return paths

def _load_yml(self, file_name: str) -> dict[str, Any]:
yaml = YAML()
Expand Down
6 changes: 3 additions & 3 deletions glotaran/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
reexports functions from the pluginsystem from a common place.
"""

from glotaran.io.interface import SAVING_OPTIONS_DEFAULT
from glotaran.io.interface import SAVING_OPTIONS_MINIMAL
from glotaran.io.interface import DataIoInterface
from glotaran.io.interface import ProjectIoInterface
from glotaran.io.interface import SavingOptions
from glotaran.io.prepare_dataset import prepare_time_trace_dataset
from glotaran.plugin_system.data_io_registration import data_io_plugin_table
from glotaran.plugin_system.data_io_registration import get_dataloader
Expand All @@ -17,9 +20,6 @@
from glotaran.plugin_system.data_io_registration import save_dataset
from glotaran.plugin_system.data_io_registration import set_data_plugin
from glotaran.plugin_system.data_io_registration import show_data_io_method_help
from glotaran.plugin_system.project_io_registration import SAVING_OPTIONS_DEFAULT
from glotaran.plugin_system.project_io_registration import SAVING_OPTIONS_MINIMAL
from glotaran.plugin_system.project_io_registration import SavingOptions
from glotaran.plugin_system.project_io_registration import get_project_io_method
from glotaran.plugin_system.project_io_registration import load_model
from glotaran.plugin_system.project_io_registration import load_parameters
Expand Down
Loading