diff --git a/.github/workflows/CI_CD_actions.yml b/.github/workflows/CI_CD_actions.yml
index 44fbc3686..02fb0fb0e 100644
--- a/.github/workflows/CI_CD_actions.yml
+++ b/.github/workflows/CI_CD_actions.yml
@@ -19,7 +19,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v4
         with:
-          python-version: 3.8
+          python-version: "3.10"
       - name: Run pre-commit
         uses: pre-commit/action@v3.0.0
@@ -29,10 +29,10 @@ jobs:
     steps:
       - name: Check out repo
         uses: actions/checkout@v3
-      - name: Set up Python 3.8
+      - name: Set up Python 3.10
         uses: actions/setup-python@v4
         with:
-          python-version: 3.8
+          python-version: "3.10"
       - name: Install check manifest
         run: python -m pip install check-manifest 'setuptools>=62.4.0'
       - name: Run check manifest
@@ -47,10 +47,10 @@
         with:
          conda-channels: conda-forge
          activate-conda: false
-      - name: Set up Python 3.8
+      - name: Set up Python 3.10
         uses: actions/setup-python@v4
         with:
-          python-version: 3.8
+          python-version: "3.10"
       - name: Install dependencies
         run: |
           conda install -y pandoc
@@ -72,10 +72,10 @@
         with:
          conda-channels: conda-forge
          activate-conda: false
-      - name: Set up Python 3.8
+      - name: Set up Python 3.10
         uses: actions/setup-python@v4
         with:
-          python-version: 3.8
+          python-version: "3.10"
       - name: Install dependencies
         run: |
           conda install -y pandoc
@@ -96,10 +96,10 @@
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
-      - name: Set up Python 3.8
+      - name: Set up Python 3.10
         uses: actions/setup-python@v4
         with:
-          python-version: 3.8
+          python-version: "3.10"
       - name: Install dependencies
         run: |
           python -m pip install -U pip wheel
@@ -117,7 +117,7 @@
     strategy:
       matrix:
         os: [ubuntu-latest, windows-latest, macOS-latest]
-        python-version: [3.8, 3.9, "3.10"]
+        python-version: ["3.10"]
     steps:
       - uses: actions/checkout@v3
@@ -147,10 +147,10 @@
     needs: [test, docs, docs-notebooks]
     steps:
       - uses: actions/checkout@v3
-      - name: Set up Python 3.8
+      - name: Set up Python 3.10
         uses: actions/setup-python@v4
         with:
-          python-version: 3.8
+          python-version: "3.10"
       - name: Install dependencies
         run: |
           python -m pip install -U pip wheel 'setuptools>=62.4.0'
diff --git a/.github/workflows/doxygen.yml b/.github/workflows/doxygen.yml
index 22a2a1b7e..ee30378cc 100644
--- a/.github/workflows/doxygen.yml
+++ b/.github/workflows/doxygen.yml
@@ -17,10 +17,10 @@ jobs:
     name: Doxygen callgraph
     runs-on: ubuntu-latest
     steps:
-      - name: Set up Python 3.8
+      - name: Set up Python 3.10
         uses: actions/setup-python@v4
         with:
-          python-version: 3.8
+          python-version: "3.10"
       - name: Install dependencies
         run: pip install requests
       - name: Get ref to check out
diff --git a/.github/workflows/full_benchmark.yml b/.github/workflows/full_benchmark.yml
index 4dec6ef74..d5c0f9bad 100644
--- a/.github/workflows/full_benchmark.yml
+++ b/.github/workflows/full_benchmark.yml
@@ -15,10 +15,10 @@ jobs:
         with:
           fetch-depth: 0
           ref: main
-      - name: Set up Python 3.8
+      - name: Set up Python 3.10
         uses: actions/setup-python@v4
         with:
-          python-version: 3.8
+          python-version: "3.10"
      - name: Install ASV
        run: |
          pip install asv virtualenv
diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index 7d3ce33b2..c09b1b6fa 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -32,10 +32,10 @@ jobs:
         example_name: ${{fromJson(needs.create-example-list.outputs.example-list)}}
     steps:
       - uses: actions/checkout@v3
-      - name: Set up Python 3.8
+      - name: Set up Python 3.10
         uses: actions/setup-python@v4
         with:
-          python-version: 3.8
+          python-version: "3.10"
"3.10" - name: Install pyglotaran run: | pip install wheel @@ -93,10 +93,10 @@ jobs: echo "♻️ pyglotaran-examples commit: $(< comparison-results-current/example_commit_sha.txt)" echo "::endgroup::" - - name: Set up Python 3.8 + - name: Set up Python 3.10 uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: "3.10" - name: Install dependencies run: | diff --git a/.github/workflows/pr_benchmark.yml b/.github/workflows/pr_benchmark.yml index aefef56fc..8f9f7d7b4 100644 --- a/.github/workflows/pr_benchmark.yml +++ b/.github/workflows/pr_benchmark.yml @@ -12,10 +12,10 @@ jobs: with: fetch-depth: 0 ref: ${{ github.event.pull_request.head.sha }} - - name: Set up Python 3.8 + - name: Set up Python 3.10 uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: "3.10" - name: Install ASV run: | pip install 'asv!=0.5' virtualenv diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1971af61a..9c29c32db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,6 +22,7 @@ repos: types: [file] types_or: [python, pyi] args: [--py38-plus] + exclude: "glotaran.model.item" - repo: https://github.com/MarcoGorelli/absolufy-imports rev: v0.3.1 @@ -105,7 +106,7 @@ repos: args: - "--select=D,DAR" name: "flake8 lint docstrings" - files: "^glotaran/(plugin_system|utils|deprecation|testing|optimization|parameter|project|simulation|model/property.py|builtin/io/pandas)" + files: "^glotaran/(plugin_system|utils|deprecation|testing|optimization|parameter|project|simulation|model|builtin/io/pandas)" exclude: "docs|tests?/" additional_dependencies: [flake8-docstrings, darglint==1.8.0] @@ -113,9 +114,9 @@ repos: rev: v0.982 hooks: - id: mypy - files: "^glotaran/(plugin_system|utils|deprecation|testing|optimization|parameter|project|simulation|model/property.py|builtin/io/pandas)" + files: "^glotaran/(plugin_system|utils|deprecation|testing|optimization|parameter|project|simulation|model|builtin/io/pandas)" exclude: "docs" - additional_dependencies: [types-all] + additional_dependencies: [types-all, types-attrs] - repo: https://github.com/econchick/interrogate rev: 1.5.0 diff --git a/.sourcery.yaml b/.sourcery.yaml index bf2bb3115..a33701e99 100644 --- a/.sourcery.yaml +++ b/.sourcery.yaml @@ -5,7 +5,7 @@ refactor: - simplify-boolean-comparison - simplify-len-comparison - remove-unnecessary-cast - python_version: "3.8" + python_version: "3.10" metrics: quality_threshold: 25.0 diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 69552c6b5..153bccd67 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -131,7 +131,7 @@ Before you submit a pull request, check that it meets these guidelines: 1. The pull request should include tests. 2. If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a `docstring`_. -3. The pull request should work for Python 3.8 and 3.9 +3. The pull request should work for Python 3.10 Check your Github Actions ``https://github.com//pyglotaran/actions`` and make sure that the tests pass for all supported Python versions. 
diff --git a/README.md b/README.md
index 454d3e5d5..5db88ec5f 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ A common use case for the library is the analysis of time-resolved spectroscopy

 Prerequisites:

-- Python 3.8, 3.9 or 3.10
+- Python 3.10
 - On Windows only 64bit is supported

 Note for Windows Users: The easiest way to get python for Windows is via [Anaconda](https://www.anaconda.com/)
diff --git a/benchmark/benchmarks/integration/ex_two_datasets/benchmark.py b/benchmark/benchmarks/integration/ex_two_datasets/benchmark.py
index 5c5830900..16f191634 100644
--- a/benchmark/benchmarks/integration/ex_two_datasets/benchmark.py
+++ b/benchmark/benchmarks/integration/ex_two_datasets/benchmark.py
@@ -1,5 +1,7 @@
 from pathlib import Path

+from glotaran import __version__
+
 try:
     # 0.4.0 - 0.5.1
     from glotaran.analysis.optimize import optimize
@@ -28,15 +30,20 @@ class IntegrationTwoDatasets:
     def setup(self):
         dataset1 = load_dataset(SCRIPT_DIR / "data/data1.ascii")
         dataset2 = load_dataset(SCRIPT_DIR / "data/data2.ascii")
-        model = load_model(str(SCRIPT_DIR / "models/model.yml"))
         parameters = load_parameters(str(SCRIPT_DIR / "models/parameters.yml"))
+        addition_kwargs = {}
+        if int(__version__.split(".")[1]) < 7:
+            model = load_model(str(SCRIPT_DIR / "models/model_lt_0.7.0.yml"))
+            addition_kwargs["non_negative_least_squares"] = True
+        else:
+            model = load_model(str(SCRIPT_DIR / "models/model.yml"))
         self.scheme = Scheme(
             model,
             parameters,
             {"dataset1": dataset1, "dataset2": dataset2},
             maximum_number_function_evaluations=11,
-            non_negative_least_squares=True,
             optimization_method="TrustRegionReflection",
+            **addition_kwargs,
         )

     def time_optimize(self):
diff --git a/benchmark/benchmarks/integration/ex_two_datasets/models/model.yml b/benchmark/benchmarks/integration/ex_two_datasets/models/model.yml
index 0e6352975..e5dd0aa52 100644
--- a/benchmark/benchmarks/integration/ex_two_datasets/models/model.yml
+++ b/benchmark/benchmarks/integration/ex_two_datasets/models/model.yml
@@ -1,4 +1,4 @@
-type: kinetic-spectrum
+default_megacomplex: decay

 dataset:
   dataset1:
@@ -33,16 +33,16 @@ initial_concentration:

 irf:
   irf1:
-    type: spectral-multi-gaussian
-    center: [irf.center]
-    width: [irf.width]
+    type: gaussian
+    center: irf.center
+    width: irf.width
   irf1_no_dispersion:
-    type: spectral-multi-gaussian
-    center: [irf.center]
-    width: [irf.width]
+    type: gaussian
+    center: irf.center
+    width: irf.width

 # It works without equal_area_penalties but then the inputs cannot be estimated
-equal_area_penalties:
+clp_penalties:
   - type: equal_area
     source: s1
     source_intervals: [[300, 3000]]
@@ -57,7 +57,6 @@
     target_intervals: [[300, 3000]]
     parameter: area.1
     weight: 0.1
-
 # Example of weight application:
 # weights:
 #   - datasets: [dataset1, dataset2]
diff --git a/benchmark/benchmarks/integration/ex_two_datasets/models/model_lt_0.7.0.yml b/benchmark/benchmarks/integration/ex_two_datasets/models/model_lt_0.7.0.yml
new file mode 100644
index 000000000..664934ed7
--- /dev/null
+++ b/benchmark/benchmarks/integration/ex_two_datasets/models/model_lt_0.7.0.yml
@@ -0,0 +1,65 @@
+type: kinetic-spectrum
+
+dataset:
+  dataset1:
+    megacomplex: [complex1]
+    initial_concentration: input1
+    irf: irf1
+    scale: scale.1
+  dataset2:
+    megacomplex: [complex1]
+    initial_concentration: input2
+    irf: irf1
+    scale: scale.2
+
+megacomplex:
+  complex1:
+    k_matrix: [km1]
+
+k_matrix:
+  km1:
+    matrix:
+      (s1, s1): "rates.k1"
+      (s2, s2): "rates.k2"
+      (s3, s3): "rates.k3"
+
+initial_concentration:
+  input1:
+    compartments: [s1, s2, s3]
+    parameters: [inputs.1, inputs.2, inputs.3]
+  input2:
+    compartments: [s1, s2, s3]
+    parameters: [inputs.1, inputs.7, inputs.8]
+
+irf:
+  irf1:
+    type: spectral-multi-gaussian
+    center: [irf.center]
+    width: [irf.width]
+  irf1_no_dispersion:
+    type: spectral-multi-gaussian
+    center: [irf.center]
+    width: [irf.width]
+
+# It works without equal_area_penalties but then the inputs cannot be estimated
+equal_area_penalties:
+  - type: equal_area
+    source: s1
+    source_intervals: [[300, 3000]]
+    target: s2
+    target_intervals: [[300, 3000]]
+    parameter: area.1
+    weight: 0.1
+  - type: equal_area
+    source: s1
+    source_intervals: [[300, 3000]]
+    target: s3
+    target_intervals: [[300, 3000]]
+    parameter: area.1
+    weight: 0.1
+# Example of weight application:
+# weights:
+#   - datasets: [dataset1, dataset2]
+#     global_interval: [100, 102]
+#     model_interval: [301, 502]
+#     value: 0.95
diff --git a/benchmark/pytest/analysis/test_optimization_group.py b/benchmark/pytest/analysis/test_optimization_group.py
index 38d03d4a5..55c259647 100644
--- a/benchmark/pytest/analysis/test_optimization_group.py
+++ b/benchmark/pytest/analysis/test_optimization_group.py
@@ -10,7 +10,7 @@
 from glotaran.model import Model
 from glotaran.model import megacomplex
 from glotaran.optimization.optimization_group import OptimizationGroup
-from glotaran.parameter import ParameterGroup
+from glotaran.parameter import Parameters
 from glotaran.project import Scheme
 from glotaran.testing.plugin_system import monkeypatch_plugin_registry

@@ -32,11 +32,15 @@
     np.ones((TEST_AXIS_GLOBAL_SIZE, TEST_AXIS_MODEL_SIZE)),
     coords=(("global", TEST_AXIS_GLOBAL.data), ("test", TEST_AXIS_MODEL.data)),
 )
-TEST_PARAMETERS = ParameterGroup.from_list([])
+TEST_PARAMETERS = Parameters.from_list([])


-@megacomplex(dimension="test", properties={"is_index_dependent": bool})
+@megacomplex()
 class BenchmarkMegacomplex(Megacomplex):
+    dimension: str = "test"
+    type: str = "benchmark"
+    is_index_dependent: bool
+
     def calculate_matrix(
         self,
         dataset_model,
@@ -60,10 +64,13 @@ def finalize_data(
         pass


+BenchmarkModel = Model.create_class_from_megacomplexes([BenchmarkMegacomplex])
+
+
 @monkeypatch_plugin_registry(test_megacomplex={"benchmark": BenchmarkMegacomplex})
 def setup_model(index_dependent, link_clp):
     model_dict = {
-        "megacomplex": {"m1": {"is_index_dependent": index_dependent}},
+        "megacomplex": {"m1": {"type": "benchmark", "is_index_dependent": index_dependent}},
         "dataset_groups": {"default": {"link_clp": link_clp}},
         "dataset": {
             "dataset1": {"megacomplex": ["m1"]},
@@ -71,11 +78,7 @@ def setup_model(index_dependent, link_clp):
             "dataset3": {"megacomplex": ["m1"]},
         },
     }
-    return Model.from_dict(
-        model_dict,
-        megacomplex_types={"benchmark": BenchmarkMegacomplex},
-        default_megacomplex_type="benchmark",
-    )
+    return BenchmarkModel(**model_dict)


 def setup_scheme(model):
diff --git a/binder/environment.yml b/binder/environment.yml
index e50faebc3..8032cae2f 100644
--- a/binder/environment.yml
+++ b/binder/environment.yml
@@ -6,7 +6,7 @@ dependencies:
   - jupyter-offlinenotebook=0.2
   # Python Kernel
   - ipykernel>5.1
-  - python=3.8
+  - python=3.10
   # update outdated repo2docker version
   - pip
   - pip:
diff --git a/changelog.md b/changelog.md
index fe0c10a5e..6fc6021f9 100644
--- a/changelog.md
+++ b/changelog.md
@@ -4,9 +4,15 @@

 ## 🚀 0.7.0 (Unreleased)

+### 💥 BREAKING CHANGE
+
+- 💥🚧 Dropped support for Python 3.8 and 3.9 and only support 3.10 (#1135)
+
 ### ✨ Features

 - ✨ Add optimization history to result and iteration column to parameter history (#1134)
+- ♻️ Complete refactor of model and parameter packages using attrs (#1135)

 ### 👌 Minor Improvements:

@@ -21,8 +27,20 @@

 ### 🗑️ Deprecations (due in 0.9.0)

+### 🗑️ Deprecations (due in 0.8.0)
+
+- `.clp_area_penalties` -> `.clp_penalties`
+- `glotaran.ParameterGroup` -> `glotaran.Parameters`
+
 ### 🗑️❌ Deprecated functionality removed in this release

+- `glotaran.project.Scheme(..., non_negative_least_squares=...)`
+- `glotaran.project.Scheme(..., group=...)`
+- `glotaran.project.Scheme(..., group_tolerance=...)`
+- `.non-negative-least-squares: true`
+- `.non-negative-least-squares: false`
+- `glotaran.parameter.ParameterGroup.to_csv(file_name=parameters.csv)`
+
 ### 🚧 Maintenance

 - 🚇🩹 Fix wrong comparison in pr_benchmark workflow (#1097)
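The deprecations listed in the changelog map one-to-one onto the new API. A minimal migration sketch (the parameter values are placeholders):

```python
from glotaran.parameter import Parameters

# Old (deprecated by this refactor):
# from glotaran.parameter import ParameterGroup
# parameters = ParameterGroup.from_list([101e-4, 501e-3])

# New:
parameters = Parameters.from_list([101e-4, 501e-3])

# Nested specs keep working the same way under the new class.
nested = Parameters.from_dict({"rates": [["k1", 0.5], ["k2", 0.1]]})
print(nested.get("rates.k1").value)  # 0.5
```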
diff --git a/glotaran/builtin/io/pandas/csv.py b/glotaran/builtin/io/pandas/csv.py
index 621cd0f64..51039461d 100644
--- a/glotaran/builtin/io/pandas/csv.py
+++ b/glotaran/builtin/io/pandas/csv.py
@@ -7,7 +7,7 @@
 from glotaran.io import ProjectIoInterface
 from glotaran.io import register_project_io
-from glotaran.parameter import ParameterGroup
+from glotaran.parameter import Parameters
 from glotaran.utils.io import safe_dataframe_fillna
 from glotaran.utils.io import safe_dataframe_replace

@@ -16,7 +16,7 @@ class CsvProjectIo(ProjectIoInterface):
     """Plugin for CSV data io."""

-    def load_parameters(self, file_name: str, sep: str = ",") -> ParameterGroup:
+    def load_parameters(self, file_name: str, sep: str = ",") -> Parameters:
         """Load parameters from CSV file.

         Parameters
@@ -28,27 +28,27 @@

         Returns
         -------
-        :class:`ParameterGroup`
+        :class:`Parameters`
         """
         df = pd.read_csv(file_name, skipinitialspace=True, na_values=["None", "none"], sep=sep)
         safe_dataframe_fillna(df, "minimum", -np.inf)
         safe_dataframe_fillna(df, "maximum", np.inf)
-        return ParameterGroup.from_dataframe(df, source=file_name)
+        return Parameters.from_dataframe(df, source=file_name)

     def save_parameters(
         self,
-        parameters: ParameterGroup,
+        parameters: Parameters,
         file_name: str,
         *,
         sep: str = ",",
         as_optimized: bool = True,
         replace_infinfinity: bool = True,
     ) -> None:
-        """Save a :class:`ParameterGroup` to a CSV file.
+        """Save a :class:`Parameters` to a CSV file.

         Parameters
         ----------
-        parameters : ParameterGroup
+        parameters : Parameters
             Parameters to be saved to file.
         file_name : str
             File to write the parameters to.
@@ -59,7 +59,7 @@
         replace_infinfinity : bool
             Whether to replace infinity values with empty strings.
         """
-        df = parameters.to_dataframe(as_optimized=as_optimized)
+        df = parameters.to_dataframe()
         if replace_infinfinity is True:
             safe_dataframe_replace(df, "minimum", -np.inf, "")
             safe_dataframe_replace(df, "maximum", np.inf, "")
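For reference, this plugin is exercised through the generic io functions rather than instantiated directly; a minimal roundtrip sketch (file names are placeholders):

```python
from glotaran.io import load_parameters, save_parameters

# Format is picked from the file extension and dispatched to CsvProjectIo.
parameters = load_parameters("parameters.csv")

save_parameters(
    parameters=parameters,
    file_name="parameters_out.csv",
    # Keep literal -inf/inf instead of the empty-string convention.
    replace_infinfinity=False,
)
```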
""" - df = parameters.to_dataframe(as_optimized=as_optimized) + df = parameters.to_dataframe() if replace_infinfinity is True: safe_dataframe_replace(df, "minimum", -np.inf, "") safe_dataframe_replace(df, "maximum", np.inf, "") diff --git a/glotaran/builtin/io/pandas/test/data/reference_parameters.csv b/glotaran/builtin/io/pandas/test/data/reference_parameters.csv index 730ca0a46..53d3a880f 100644 --- a/glotaran/builtin/io/pandas/test/data/reference_parameters.csv +++ b/glotaran/builtin/io/pandas/test/data/reference_parameters.csv @@ -1,8 +1,8 @@ -label,value,expression,minimum,maximum,non-negative,vary,standard-error -pure_list.1,1.0,None,,,False,True,None -pure_list.2,2.0,None,,,False,True,None -list_with_options.1,3.0,None,,,False,False,None -list_with_options.2,4.0,None,,,False,False,None -verbose_list.all_defaults,5.0,None,,,False,True,None -verbose_list.no_defaults,6.0,None,,,True,False,None -verbose_list.expression_only,11.0,$verbose_list.all_defaults + $verbose_list.no_defaults,,,False,False,None +label,value,standard_error,expression,maximum,minimum,non_negative,vary +pure_list.1,1.0,None,None,,,False,True +pure_list.2,2.0,None,None,,,False,True +list_with_options.1,3.0,None,None,,,False,False +list_with_options.2,4.0,None,None,,,False,False +verbose_list.all_defaults,5.0,None,None,,,False,True +verbose_list.no_defaults,6.0,None,None,1.0,-1.0,True,False +verbose_list.expression_only,11.0,None,$verbose_list.all_defaults + $verbose_list.no_defaults,,,False,False diff --git a/glotaran/builtin/io/pandas/test/data/reference_parameters.ods b/glotaran/builtin/io/pandas/test/data/reference_parameters.ods index 7750cc793..48245cce9 100644 Binary files a/glotaran/builtin/io/pandas/test/data/reference_parameters.ods and b/glotaran/builtin/io/pandas/test/data/reference_parameters.ods differ diff --git a/glotaran/builtin/io/pandas/test/data/reference_parameters.tsv b/glotaran/builtin/io/pandas/test/data/reference_parameters.tsv index 331310219..1f215d88a 100644 --- a/glotaran/builtin/io/pandas/test/data/reference_parameters.tsv +++ b/glotaran/builtin/io/pandas/test/data/reference_parameters.tsv @@ -1,8 +1,8 @@ -label value expression minimum maximum non-negative vary standard-error -pure_list.1 1.0 None False True None -pure_list.2 2.0 None False True None -list_with_options.1 3.0 None False False None -list_with_options.2 4.0 None False False None -verbose_list.all_defaults 5.0 None False True None -verbose_list.no_defaults 6.0 None True False None -verbose_list.expression_only 11.0 $verbose_list.all_defaults + $verbose_list.no_defaults False False None +label value standard_error expression maximum minimum non_negative vary +pure_list.1 1.0 None None False True +pure_list.2 2.0 None None False True +list_with_options.1 3.0 None None False False +list_with_options.2 4.0 None None False False +verbose_list.all_defaults 5.0 None None False True +verbose_list.no_defaults 6.0 None None 1.0 -1.0 True False +verbose_list.expression_only 11.0 None $verbose_list.all_defaults + $verbose_list.no_defaults False False diff --git a/glotaran/builtin/io/pandas/test/data/reference_parameters.xlsx b/glotaran/builtin/io/pandas/test/data/reference_parameters.xlsx index 9c35268dd..bcd7cf2e6 100644 Binary files a/glotaran/builtin/io/pandas/test/data/reference_parameters.xlsx and b/glotaran/builtin/io/pandas/test/data/reference_parameters.xlsx differ diff --git a/glotaran/builtin/io/pandas/test/data/reference_parameters.yaml b/glotaran/builtin/io/pandas/test/data/reference_parameters.yaml index 
index c6d1c9a02..233e2c1ee 100644
--- a/glotaran/builtin/io/pandas/test/data/reference_parameters.yaml
+++ b/glotaran/builtin/io/pandas/test/data/reference_parameters.yaml
@@ -1,8 +1,15 @@
 pure_list: [1.0, 2.0]

-list_with_options: [3.0, 4.0, {vary: False}]
+list_with_options: [3.0, 4.0, { vary: False }]

 verbose_list:
   - ["all_defaults", 5.0]
-  - ["no_defaults", 6.0, {non-negative: True, vary: False, minimum: -1, maximum: 1}]
-  - ["expression_only", {expr: $verbose_list.all_defaults + $verbose_list.no_defaults}]
+  - [
+      "no_defaults",
+      6.0,
+      { non-negative: True, vary: False, min: -1.0, max: 1.0 },
+    ]
+  - [
+      "expression_only",
+      { expr: $verbose_list.all_defaults + $verbose_list.no_defaults },
+    ]
diff --git a/glotaran/builtin/io/pandas/test/test_pandas_parameters.py b/glotaran/builtin/io/pandas/test/test_pandas_parameters.py
index 0b5b1d29b..8e07da47f 100644
--- a/glotaran/builtin/io/pandas/test/test_pandas_parameters.py
+++ b/glotaran/builtin/io/pandas/test/test_pandas_parameters.py
@@ -9,7 +9,7 @@
 from glotaran.io import load_parameters
 from glotaran.io import save_parameters
-from glotaran.parameter import ParameterGroup
+from glotaran.parameter import Parameters

 PANDAS_TEST_DATA = Path(__file__).parent / "data"
 PATH_XLSX = PANDAS_TEST_DATA / "reference_parameters.xlsx"
@@ -19,13 +19,13 @@

 @pytest.fixture(scope="module")
-def yaml_reference() -> ParameterGroup:
+def yaml_reference() -> Parameters:
     """Fixture for yaml reference data."""
     return load_parameters(PANDAS_TEST_DATA / "reference_parameters.yaml")


 @pytest.mark.parametrize("reference_path", (PATH_XLSX, PATH_ODS, PATH_CSV, PATH_TSV))
-def test_references(yaml_reference: ParameterGroup, reference_path: Path):
+def test_references(yaml_reference: Parameters, reference_path: Path):
     """References are the same"""
     result = load_parameters(reference_path)
     assert result == yaml_reference
@@ -36,7 +36,7 @@
     (("xlsx", PATH_XLSX), ("ods", PATH_ODS), ("csv", PATH_CSV), ("tsv", PATH_TSV)),
 )
 def test_roundtrips(
-    yaml_reference: ParameterGroup, tmp_path: Path, format_name: str, reference_path: Path
+    yaml_reference: Parameters, tmp_path: Path, format_name: str, reference_path: Path
 ):
     """Roundtrip via save and load have the same data."""
     format_reference = load_parameters(reference_path)
@@ -62,31 +62,10 @@
     )


-@pytest.mark.parametrize("format_name", ("xlsx", "ods", "csv", "tsv"))
-def test_as_optimized_false(yaml_reference: ParameterGroup, tmp_path: Path, format_name: str):
-    """Column 'standard-error' is missing if as_optimized==False"""
-    parameter_path = tmp_path / f"test_parameters.{format_name}"
-    save_parameters(
-        file_name=parameter_path,
-        format_name=format_name,
-        parameters=yaml_reference,
-        as_optimized=False,
-    )
-
-    if format_name in {"csv", "tsv"}:
-        assert "standard-error" not in parameter_path.read_text().splitlines()[0]
-    else:
-        assert (
-            "standard-error"
-            not in pd.read_excel(parameter_path, na_values=["None", "none"]).columns
-        )
-
-
 @pytest.mark.parametrize("format_name,sep", (("csv", ","), ("tsv", "\t")))
 def test_replace_infinfinity(
-    yaml_reference: ParameterGroup, tmp_path: Path, format_name: str, sep: str
+    yaml_reference: Parameters, tmp_path: Path, format_name: str, sep: str
 ):
-    """Column 'standard-error' is missing if as_optimized==False"""
     parameter_path = tmp_path / f"test_parameters.{format_name}"
     save_parameters(
         file_name=parameter_path,
         format_name=format_name,
         parameters=yaml_reference,
         replace_infinfinity=False,
     )
     df = pd.read_csv(parameter_path, sep=sep)
-    assert all(df["minimum"] == -np.inf)
+    df = df[df["label"] != "verbose_list.no_defaults"]
     assert all(df["maximum"] == np.inf)
+    assert all(df["minimum"] == -np.inf)

     first_data_line = parameter_path.read_text().splitlines()[1]
     assert f"{sep}-inf" in first_data_line
diff --git a/glotaran/builtin/io/pandas/tsv.py b/glotaran/builtin/io/pandas/tsv.py
index 3d49b8580..e0f2c3bc1 100644
--- a/glotaran/builtin/io/pandas/tsv.py
+++ b/glotaran/builtin/io/pandas/tsv.py
@@ -10,14 +10,14 @@
 from glotaran.io import save_parameters

 if TYPE_CHECKING:
-    from glotaran.parameter import ParameterGroup
+    from glotaran.parameter import Parameters


 @register_project_io(["tsv"])
 class TsvProjectIo(ProjectIoInterface):
     """Plugin for TSV data io."""

-    def load_parameters(self, file_name: str) -> ParameterGroup:
+    def load_parameters(self, file_name: str) -> Parameters:
         """Load parameters from TSV file.

         Parameters
@@ -27,28 +27,25 @@

         Returns
         -------
-        :class:`ParameterGroup`
+        :class:`Parameters`
         """
         return load_parameters(file_name, format_name="csv", sep="\t")

     def save_parameters(
         self,
-        parameters: ParameterGroup,
+        parameters: Parameters,
         file_name: str,
         *,
-        as_optimized: bool = True,
         replace_infinfinity: bool = True,
     ) -> None:
-        """Save a :class:`ParameterGroup` to a TSV file.
+        """Save a :class:`Parameters` to a TSV file.

         Parameters
         ----------
-        parameters : ParameterGroup
+        parameters : Parameters
             Parameters to be saved to file.
         file_name : str
             File to write the parameters to.
-        as_optimized : bool
-            Whether to include properties which are the result of optimization.
         replace_infinfinity : bool
             Whether to replace infinity values with empty strings.
         """
@@ -57,6 +54,5 @@
             file_name,
             format_name="csv",
             sep="\t",
-            as_optimized=as_optimized,
             replace_infinfinity=replace_infinfinity,
         )
diff --git a/glotaran/builtin/io/pandas/xlsx.py b/glotaran/builtin/io/pandas/xlsx.py
index 5a0087b31..700a41a0c 100644
--- a/glotaran/builtin/io/pandas/xlsx.py
+++ b/glotaran/builtin/io/pandas/xlsx.py
@@ -7,7 +7,7 @@
 from glotaran.io import ProjectIoInterface
 from glotaran.io import register_project_io
-from glotaran.parameter import ParameterGroup
+from glotaran.parameter import Parameters
 from glotaran.utils.io import safe_dataframe_fillna
 from glotaran.utils.io import safe_dataframe_replace

@@ -16,7 +16,7 @@ class ExcelProjectIo(ProjectIoInterface):
     """Plugin for Excel like data io."""

-    def load_parameters(self, file_name: str) -> ParameterGroup:
+    def load_parameters(self, file_name: str) -> Parameters:
         """Load parameters from XLSX file.

         Parameters
@@ -26,28 +26,24 @@

         Returns
         -------
-        :class:`ParameterGroup`
+        :class:`Parameters`
         """
         df = pd.read_excel(file_name, na_values=["None", "none"])
         safe_dataframe_fillna(df, "minimum", -np.inf)
         safe_dataframe_fillna(df, "maximum", np.inf)
-        return ParameterGroup.from_dataframe(df, source=file_name)
+        return Parameters.from_dataframe(df, source=file_name)

-    def save_parameters(
-        self, parameters: ParameterGroup, file_name: str, *, as_optimized: bool = True
-    ):
-        """Save a :class:`ParameterGroup` to a Excel file.
+    def save_parameters(self, parameters: Parameters, file_name: str):
+        """Save a :class:`Parameters` to an Excel file.

         Parameters
         ----------
-        parameters : ParameterGroup
+        parameters : Parameters
             Parameters to be saved to file.
         file_name : str
             File to write the parameters to.
-        as_optimized : bool
-            Whether to include properties which are the result of optimization.
         """
-        df = parameters.to_dataframe(as_optimized=as_optimized)
+        df = parameters.to_dataframe()
         safe_dataframe_replace(df, "minimum", -np.inf, "")
         safe_dataframe_replace(df, "maximum", np.inf, "")
         df.to_excel(file_name, na_rep="None", index=False)
diff --git a/glotaran/builtin/io/yml/test/test_load_parameters.py b/glotaran/builtin/io/yml/test/test_load_parameters.py
new file mode 100644
index 000000000..485fdf795
--- /dev/null
+++ b/glotaran/builtin/io/yml/test/test_load_parameters.py
@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+import numpy as np
+
+from glotaran.io import load_parameters
+
+
+def test_parameter_group_copy():
+    params = """
+    a:
+        - ["foo", 1, {non-negative: true, min: -1, max: 1, vary: false}]
+        - 4
+        - 5
+    b:
+        - 7
+        - 8
+    """
+    parameters = load_parameters(params, format_name="yml_str")
+
+    assert parameters.get("a.foo").value == 1
+    assert parameters.get("a.foo").non_negative
+    assert parameters.get("a.foo").minimum == -1
+    assert parameters.get("a.foo").maximum == 1
+    assert not parameters.get("a.foo").vary
+
+    assert parameters.get("a.2").value == 4
+    assert not parameters.get("a.2").non_negative
+    assert parameters.get("a.2").minimum == -np.inf
+    assert parameters.get("a.2").maximum == np.inf
+    assert parameters.get("a.2").vary
+
+    assert parameters.get("a.3").value == 5
+
+    assert parameters.get("b.1").value == 7
+    assert not parameters.get("b.1").non_negative
+    assert parameters.get("b.1").minimum == -np.inf
+    assert parameters.get("b.1").maximum == np.inf
+    assert parameters.get("b.1").vary
+
+    assert parameters.get("b.2").value == 8
diff --git a/glotaran/builtin/io/yml/test/test_model_parser.py b/glotaran/builtin/io/yml/test/test_model_parser.py
index 712b77eee..caf37be89 100644
--- a/glotaran/builtin/io/yml/test/test_model_parser.py
+++ b/glotaran/builtin/io/yml/test/test_model_parser.py
@@ -2,7 +2,6 @@
 from os.path import dirname
 from os.path import join

-import numpy as np
 import pytest

 from glotaran.builtin.megacomplexes.decay.decay_megacomplex import DecayMegacomplex
@@ -11,12 +10,10 @@
 from glotaran.builtin.megacomplexes.spectral.shape import SpectralShapeGaussian
 from glotaran.io import load_model
 from glotaran.model import DatasetModel
-from glotaran.model import Model
+from glotaran.model import EqualAreaPenalty
+from glotaran.model import OnlyConstraint
 from glotaran.model import Weight
-from glotaran.model.clp_penalties import EqualAreaPenalty
-from glotaran.model.constraint import OnlyConstraint
-from glotaran.model.constraint import ZeroConstraint
-from glotaran.parameter import ParameterGroup
+from glotaran.model import ZeroConstraint

 THIS_DIR = dirname(abspath(__file__))

@@ -29,13 +26,6 @@ def model():
     return m


-def test_correct_model(model):
-    assert isinstance(model, Model)
-    assert "decay" == model.default_megacomplex
-    assert "decay" in model.megacomplex_types
-    assert "spectral" in model.megacomplex_types
-
-
 def test_dataset(model):
     assert len(model.dataset) == 2

@@ -46,7 +36,7 @@ def test_dataset(model):
     assert dataset.megacomplex == ["cmplx1"]
     assert dataset.initial_concentration == "inputD1"
     assert dataset.irf == "irf1"
-    assert dataset.scale.full_label == "1"
+    assert dataset.scale == "1"

     assert "dataset2" in model.dataset
     dataset = model.dataset["dataset2"]
@@ -55,7 +45,7 @@
     assert dataset.megacomplex == ["cmplx2"]
     assert dataset.initial_concentration == "inputD2"
     assert dataset.irf == "irf2"
-    assert dataset.scale.full_label == "2"
+    assert dataset.scale == "2"
     assert dataset.spectral_axis_scale == 1e7
     assert dataset.spectral_axis_inverted

@@ -76,14 +66,15 @@ def test_constraints(model):


 def test_penalties(model):
-    assert len(model.clp_area_penalties) == 1
-    assert all(isinstance(c, EqualAreaPenalty) for c in model.clp_area_penalties)
-    eac = model.clp_area_penalties[0]
+    assert len(model.clp_penalties) == 1
+    assert all(isinstance(c, EqualAreaPenalty) for c in model.clp_penalties)
+    eac = model.clp_penalties[0]
+    assert eac.type == "equal_area"
     assert eac.source == "s3"
     assert eac.source_intervals == [[670, 810]]
     assert eac.target == "s2"
     assert eac.target_intervals == [[670, 810]]
-    assert eac.parameter.full_label == "55"
+    assert eac.parameter == "55"
     assert eac.weight == 0.0016

@@ -108,7 +99,7 @@ def test_initial_concentration(model):
     assert initial_concentration.compartments == ["s1", "s2", "s3"]
     assert isinstance(initial_concentration, InitialConcentration)
     assert initial_concentration.label == label
-    assert [p.full_label for p in initial_concentration.parameters] == ["1", "2", "3"]
+    assert initial_concentration.parameters == ["1", "2", "3"]


 def test_irf(model):
@@ -121,17 +112,17 @@ def test_irf(model):
         assert isinstance(irf, IrfMultiGaussian)
         assert irf.label == label
         want = ["1"] if i == 1 else ["1", "2"]
-        assert [p.full_label for p in irf.center] == want
+        assert irf.center == want
         want = ["2"] if i == 1 else ["3", "4"]
-        assert [p.full_label for p in irf.width] == want
+        assert irf.width == want
         if i == 2:
             want = ["3"] if i == 1 else ["5", "6"]
-            assert [p.full_label for p in irf.center_dispersion_coefficients] == want
+            assert irf.center_dispersion_coefficients == want
             want = ["7", "8"]
-            assert [p.full_label for p in irf.width_dispersion_coefficients] == want
+            assert irf.width_dispersion_coefficients == want
             want = ["9"]
-            assert [p.full_label for p in irf.scale] == want
+            assert irf.scale == want
         assert irf.normalize == (i == 1)

         if i == 2:
@@ -144,14 +135,15 @@

 def test_k_matrices(model):
     assert "km1" in model.k_matrix
-    parameter = ParameterGroup.from_list([1, 2, 3, 4, 5, 6, 7])
-    print(model.k_matrix["km1"].fill(model, parameter).matrix)
-    reduced = model.k_matrix["km1"].fill(model, parameter).reduced(["s1", "s2", "s3", "s4"])
-    print(parameter)
-    print(reduced)
-    wanted = np.asarray([[1, 3, 5, 7], [2, 0, 0, 0], [4, 0, 0, 0], [6, 0, 0, 0]])
-    print(wanted)
-    assert np.array_equal(reduced, wanted)
+    assert model.k_matrix["km1"].matrix == {
+        ("s1", "s1"): "1",
+        ("s2", "s1"): "2",
+        ("s1", "s2"): "3",
+        ("s3", "s1"): "4",
+        ("s1", "s3"): "5",
+        ("s4", "s1"): "6",
+        ("s1", "s4"): "7",
+    }


 def test_weight(model):
@@ -170,9 +162,9 @@ def test_shapes(model):

     shape = model.shape["shape1"]
     assert isinstance(shape, SpectralShapeGaussian)
-    assert shape.amplitude.full_label == "shape.1"
-    assert shape.location.full_label == "shape.2"
-    assert shape.width.full_label == "shape.3"
+    assert shape.amplitude == "shape.1"
+    assert shape.location == "shape.2"
+    assert shape.width == "shape.3"


 def test_megacomplexes(model):
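These test changes reflect a core shift in the refactor: model items now store parameter *labels* as plain strings, and resolution against a `Parameters` instance happens explicitly. A minimal sketch of the new flow, using `fill_item` as the test files further down do (the spec file names are placeholders):

```python
from glotaran.io import load_model, load_parameters
from glotaran.model import fill_item

# Hypothetical spec files, standing in for the fixtures the tests use.
model = load_model("test_model_spec.yml")
parameters = load_parameters("parameters.yml")

# Items hold parameter labels as plain strings ...
dataset_model = model.dataset["dataset1"]
assert dataset_model.scale == "1"

# ... and only an explicitly filled copy carries Parameter objects.
filled = fill_item(dataset_model, model, parameters)
print(filled.scale.value)
```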
diff --git a/glotaran/builtin/io/yml/test/test_model_spec.yml b/glotaran/builtin/io/yml/test/test_model_spec.yml
index 55a7817cf..a8f762b0b 100644
--- a/glotaran/builtin/io/yml/test/test_model_spec.yml
+++ b/glotaran/builtin/io/yml/test/test_model_spec.yml
@@ -5,40 +5,40 @@ dataset:
     megacomplex: [cmplx1]
     initial_concentration: inputD1
     irf: irf1
-    scale: 1
+    scale: "1"
   dataset2:
     megacomplex: [cmplx2]
     initial_concentration: inputD2
     irf: irf2
-    scale: 2
+    scale: "2"
     spectral_axis_scale: 1e7
     spectral_axis_inverted: true

 irf:
   irf1:
     type: multi-gaussian
-    center: [1]
-    width: [2]
+    center: ["1"]
+    width: ["2"]
   irf2:
     type: spectral-multi-gaussian
-    center: [1, 2]
-    width: [3, 4]
-    scale: [9]
+    center: ["1", "2"]
+    width: ["3", "4"]
+    scale: ["9"]
     normalize: false
     backsweep: true
-    backsweep_period: 55
-    dispersion_center: 55
-    center_dispersion_coefficients: [5, 6]
-    width_dispersion_coefficients: [7, 8]
+    backsweep_period: "55"
+    dispersion_center: "55"
+    center_dispersion_coefficients: ["5", "6"]
+    width_dispersion_coefficients: ["7", "8"]
     model_dispersion_with_wavenumber: true

 initial_concentration:
   inputD1:
     compartments: [s1, s2, s3]
-    parameters: [1, 2, 3]
+    parameters: ["1", "2", "3"]
   inputD2:
     compartments: [s1, s2, s3]
-    parameters: [1, 2, 3]
+    parameters: ["1", "2", "3"]

 # Convention matrix notation column = source, row = target compartment
 # (2,1) means from 1 to 2
@@ -83,19 +83,19 @@ clp_constraints:
       - [1, 100]
       - [2, 200]

-clp_area_penalties:
-  - type: equal_area
+clp_penalties:
+  - type: "equal_area"
     source: s3
     source_intervals: [[670, 810]]
     target: s2
     target_intervals: [[670, 810]]
-    parameter: 55
+    parameter: "55"
     weight: 0.0016

 clp_relations:
   - source: s1
     target: s2
-    parameter: 8
+    parameter: "8"
     interval: [[1, 100], [2, 200]]

 weights:
diff --git a/glotaran/builtin/io/yml/test/test_save_model.py b/glotaran/builtin/io/yml/test/test_save_model.py
index 8b065d3ec..ece944953 100644
--- a/glotaran/builtin/io/yml/test/test_save_model.py
+++ b/glotaran/builtin/io/yml/test/test_save_model.py
@@ -10,21 +10,20 @@
 from pathlib import Path


-want = """default_megacomplex: decay-sequential
+want = """\
+clp_penalties: []
+clp_constraints: []
+clp_relations: []
 dataset_groups:
   default:
+    label: default
     residual_function: variable_projection
     link_clp: null
-irf:
-  gaussian_irf:
-    type: gaussian
-    center: irf.center
-    width: irf.width
-    normalize: true
-    backsweep: false
+weights: []
 megacomplex:
   megacomplex_sequential_decay:
-    type: decay-sequential
+    label: megacomplex_sequential_decay
+    dimension: time
     compartments:
       - species_1
       - species_2
@@ -33,12 +32,29 @@
       - rates.species_1
       - rates.species_2
       - rates.species_3
-    dimension: time
+    type: decay-sequential
+irf:
+  gaussian_irf:
+    label: gaussian_irf
+    scale: null
+    shift: null
+    normalize: true
+    backsweep: false
+    backsweep_period: null
+    type: gaussian
+    center: irf.center
+    width: irf.width
 dataset:
   dataset_1:
+    label: dataset_1
     group: default
+    force_index_dependent: false
     megacomplex:
       - megacomplex_sequential_decay
+    megacomplex_scale: null
+    global_megacomplex: null
+    global_megacomplex_scale: null
+    scale: null
     irf: gaussian_irf
 """
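The serialization shape asserted above follows from the new dynamic model classes. Instead of `Model.from_dict` with a `megacomplex_types` argument, 0.7 builds a concrete model class from the megacomplexes it should know about; a minimal sketch mirroring the patterns in the tests:

```python
from glotaran.builtin.megacomplexes.decay import DecayMegacomplex
from glotaran.model import Model

# Create a model class that knows about the given megacomplex types.
DecayModel = Model.create_class_from_megacomplexes([DecayMegacomplex])

# Instantiate it directly from a spec dict.
model = DecayModel(
    **{
        "megacomplex": {"mc1": {"type": "decay", "k_matrix": ["k1"]}},
        "k_matrix": {"k1": {"matrix": {("s1", "s1"): "rates.1"}}},
        "initial_concentration": {
            "j1": {"compartments": ["s1"], "parameters": ["inputs.1"]}
        },
        "dataset": {"dataset1": {"megacomplex": ["mc1"], "initial_concentration": "j1"}},
    }
)
```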
model") + + default_megacomplex = spec.pop("default_megacomplex", None) if default_megacomplex is None and any( "type" not in m for m in spec["megacomplex"].values() @@ -58,10 +62,13 @@ def load_model(self, file_name: str) -> Model: "at least one megacomplex does not have a type." ) - if "megacomplex" not in spec: - raise ValueError("No megacomplex defined in model") + spec["megacomplex"] = { + label: m | {"type": default_megacomplex} if "type" not in m else m + for label, m in spec["megacomplex"].items() + } - return Model.from_dict(spec, megacomplex_types=None, default_megacomplex_type=None) + megacomplex_types = {get_megacomplex(m["type"]) for m in spec["megacomplex"].values()} + return Model.create_class_from_megacomplexes(megacomplex_types)(**spec) def save_model(self, model: Model, file_name: str): """Save a Model instance to a spec file. @@ -85,24 +92,24 @@ def save_model(self, model: Model, file_name: str): item[prop_name] = {f"{k}": v for k, v in zip(keys, prop.values())} write_dict(model_dict, file_name=file_name) - def load_parameters(self, file_name: str) -> ParameterGroup: - """Create a ParameterGroup instance from the specs defined in a file. + def load_parameters(self, file_name: str) -> Parameters: + """Create parameters instance from the specs defined in a file. Parameters ---------- file_name : str File containing the parameter specs. Returns ------- - ParameterGroup - ParameterGroup instance created from the file. + Parameters + Parameters instance created from the file. """ spec = self._load_yml(file_name) if isinstance(spec, list): - return ParameterGroup.from_list(spec) + return Parameters.from_list(spec) else: - return ParameterGroup.from_dict(spec) + return Parameters.from_dict(spec) def load_scheme(self, file_name: str) -> Scheme: spec = self._load_yml(file_name) diff --git a/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py b/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py index 41286f8ac..4223bc67e 100644 --- a/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py +++ b/glotaran/builtin/megacomplexes/baseline/baseline_megacomplex.py @@ -8,8 +8,10 @@ from glotaran.model import megacomplex -@megacomplex(unique=True, register_as="baseline") +@megacomplex(unique=True) class BaselineMegacomplex(Megacomplex): + type: str = "baseline" + def calculate_matrix( self, dataset_model: DatasetModel, diff --git a/glotaran/builtin/megacomplexes/baseline/test/test_baseline_megacomplex.py b/glotaran/builtin/megacomplexes/baseline/test/test_baseline_megacomplex.py index e1ff6f128..404865618 100644 --- a/glotaran/builtin/megacomplexes/baseline/test/test_baseline_megacomplex.py +++ b/glotaran/builtin/megacomplexes/baseline/test/test_baseline_megacomplex.py @@ -1,55 +1,29 @@ import numpy as np from glotaran.builtin.megacomplexes.baseline import BaselineMegacomplex -from glotaran.builtin.megacomplexes.decay import DecayMegacomplex from glotaran.model import Model +from glotaran.model import fill_item from glotaran.optimization.matrix_provider import MatrixProvider -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters def test_baseline(): - model = Model.from_dict( - { - "initial_concentration": { - "j1": {"compartments": ["s1"], "parameters": ["2"]}, - }, - "megacomplex": { - "mc1": {"type": "decay", "k_matrix": ["k1"]}, - "mc2": {"type": "baseline", "dimension": "time"}, - }, - "k_matrix": { - "k1": { - "matrix": { - ("s1", "s1"): "1", - } - } - }, - "dataset": { - "dataset1": { - "initial_concentration": 
"j1", - "megacomplex": ["mc1", "mc2"], - }, - }, - }, - megacomplex_types={"decay": DecayMegacomplex, "baseline": BaselineMegacomplex}, - ) - - parameter = ParameterGroup.from_list( - [ - 101e-4, - [1, {"vary": False, "non-negative": False}], - [42, {"vary": False, "non-negative": False}], - ] + model = Model.create_class_from_megacomplexes([BaselineMegacomplex])( + **{ + "megacomplex": {"m": {"type": "baseline", "dimension": "time"}}, + "dataset": {"dataset1": {"megacomplex": ["m"]}}, + } ) + parameters = Parameters({}) time = np.asarray(np.arange(0, 50, 1.5)) pixel = np.asarray([0]) - dataset_model = model.dataset["dataset1"].fill(model, parameter) + dataset_model = fill_item(model.dataset["dataset1"], model, parameters) matrix = MatrixProvider.calculate_dataset_matrix(dataset_model, None, pixel, time) compartments = matrix.clp_labels - assert len(compartments) == 2 + assert len(compartments) == 1 assert "dataset1_baseline" in compartments - assert matrix.matrix.shape == (time.size, 2) - assert np.all(matrix.matrix[:, 1] == 1) + assert matrix.matrix.shape == (time.size, 1) + assert np.all(matrix.matrix[:, 0] == 1) diff --git a/glotaran/builtin/megacomplexes/clp_guide/clp_guide_megacomplex.py b/glotaran/builtin/megacomplexes/clp_guide/clp_guide_megacomplex.py index dd478cde0..239390d2c 100644 --- a/glotaran/builtin/megacomplexes/clp_guide/clp_guide_megacomplex.py +++ b/glotaran/builtin/megacomplexes/clp_guide/clp_guide_megacomplex.py @@ -8,8 +8,11 @@ from glotaran.model import megacomplex -@megacomplex(exclusive=True, register_as="clp-guide", properties={"target": str}) +@megacomplex(exclusive=True) class ClpGuideMegacomplex(Megacomplex): + type: str = "clp-guide" + target: str + def calculate_matrix( self, dataset_model: DatasetModel, diff --git a/glotaran/builtin/megacomplexes/clp_guide/test/test_clp_guide_megacomplex.py b/glotaran/builtin/megacomplexes/clp_guide/test/test_clp_guide_megacomplex.py index 00365a0f8..6d95a4dad 100644 --- a/glotaran/builtin/megacomplexes/clp_guide/test/test_clp_guide_megacomplex.py +++ b/glotaran/builtin/megacomplexes/clp_guide/test/test_clp_guide_megacomplex.py @@ -5,15 +5,17 @@ from glotaran.builtin.megacomplexes.decay.test.test_decay_megacomplex import create_gaussian_clp from glotaran.model import Model from glotaran.optimization.optimize import optimize -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.project import Scheme from glotaran.simulation.simulation import simulate def test_clp_guide(): - model = Model.from_dict( - { + model = Model.create_class_from_megacomplexes( + [DecaySequentialMegacomplex, ClpGuideMegacomplex] + )( + **{ "dataset_groups": {"default": {"link_clp": True}}, "megacomplex": { "mc1": { @@ -28,16 +30,12 @@ def test_clp_guide(): "dataset2": {"megacomplex": ["mc2"]}, }, }, - megacomplex_types={ - "decay-sequential": DecaySequentialMegacomplex, - "clp-guide": ClpGuideMegacomplex, - }, ) - initial_parameters = ParameterGroup.from_list( + initial_parameters = Parameters.from_list( [101e-5, 501e-4, [1, {"vary": False, "non-negative": False}]] ) - wanted_parameters = ParameterGroup.from_list( + wanted_parameters = Parameters.from_list( [101e-4, 501e-3, [1, {"vary": False, "non-negative": False}]] ) @@ -59,6 +57,5 @@ def test_clp_guide(): ) result = optimize(scheme) print(result.optimized_parameters) - - for label, param in result.optimized_parameters.all(): - assert np.allclose(param.value, wanted_parameters.get(label).value, rtol=1e-1) + for param in 
+    for param in result.optimized_parameters.all():
+        assert np.allclose(param.value, wanted_parameters.get(param.label).value, rtol=1e-1)
diff --git a/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py b/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py
index bdbc7eae8..ca12fb79a 100644
--- a/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py
+++ b/glotaran/builtin/megacomplexes/coherent_artifact/coherent_artifact_megacomplex.py
@@ -5,30 +5,25 @@
 import numpy as np
 import xarray as xr

-from glotaran.builtin.megacomplexes.decay.irf import Irf
+from glotaran.builtin.megacomplexes.decay.decay_parallel_megacomplex import DecayDatasetModel
 from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian
 from glotaran.builtin.megacomplexes.decay.util import index_dependent
 from glotaran.builtin.megacomplexes.decay.util import retrieve_irf
 from glotaran.model import DatasetModel
 from glotaran.model import Megacomplex
 from glotaran.model import ModelError
+from glotaran.model import ParameterType
+from glotaran.model import is_dataset_model_index_dependent
 from glotaran.model import megacomplex
-from glotaran.parameter import Parameter
-
-
-@megacomplex(
-    dimension="time",
-    unique=True,
-    properties={
-        "order": {"type": int},
-        "width": {"type": Parameter, "allow_none": True},
-    },
-    dataset_model_items={
-        "irf": {"type": Irf, "allow_none": True},
-    },
-    register_as="coherent-artifact",
-)
+
+
+@megacomplex(dataset_model_type=DecayDatasetModel, unique=True)
 class CoherentArtifactMegacomplex(Megacomplex):
+    dimension: str = "time"
+    type: str = "coherent-artifact"
+    order: int
+    width: ParameterType | None = None
+
     def calculate_matrix(
         self,
         dataset_model: DatasetModel,
@@ -73,7 +68,7 @@ def finalize_data(
         model_dimension = dataset.attrs["model_dimension"]
         dataset.coords["coherent_artifact_order"] = np.arange(1, self.order + 1)
         response_dimensions = (model_dimension, "coherent_artifact_order")
-        if dataset_model.is_index_dependent() is True:
+        if is_dataset_model_index_dependent(dataset_model):
             response_dimensions = (global_dimension, *response_dimensions)
         dataset["coherent_artifact_response"] = (
             response_dimensions,
diff --git a/glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py b/glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py
index 519e8e259..fcbff5e0b 100644
--- a/glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py
+++ b/glotaran/builtin/megacomplexes/coherent_artifact/test/test_coherent_artifact.py
@@ -5,9 +5,10 @@
 from glotaran.builtin.megacomplexes.coherent_artifact import CoherentArtifactMegacomplex
 from glotaran.builtin.megacomplexes.decay import DecayMegacomplex
 from glotaran.model import Model
+from glotaran.model import fill_item
 from glotaran.optimization.matrix_provider import MatrixProvider
 from glotaran.optimization.optimize import optimize
-from glotaran.parameter import ParameterGroup
+from glotaran.parameter import Parameters
 from glotaran.project import Scheme
 from glotaran.simulation import simulate

@@ -34,7 +35,7 @@ def test_coherent_artifact(spectral_dependence: str):
         },
         "irf": {
             "irf1": {
-                "type": "spectral-multi-gaussian",
+                "type": "multi-gaussian",
                 "center": ["irf_center"],
                 "width": ["irf_width"],
             },
@@ -57,8 +58,9 @@

     irf_spec = model_dict["irf"]["irf1"]

     if spectral_dependence == "dispersed":
irf_spec["dispersion_center"] = "irf_dispc" - irf_spec["center_dispersion"] = ["irf_disp1", "irf_disp2"] + irf_spec["center_dispersion_coefficients"] = ["irf_disp1", "irf_disp2"] parameter_list += [ ["irf_dispc", 300, {"vary": False, "non-negative": False}], @@ -74,20 +76,16 @@ def test_coherent_artifact(spectral_dependence: str): ["irf_shift3", 2], ] - model = Model.from_dict( - model_dict.copy(), - megacomplex_types={ - "decay": DecayMegacomplex, - "coherent-artifact": CoherentArtifactMegacomplex, - }, + model = Model.create_class_from_megacomplexes([DecayMegacomplex, CoherentArtifactMegacomplex])( + **model_dict ) - parameters = ParameterGroup.from_list(parameter_list) + parameters = Parameters.from_list(parameter_list) time = np.arange(0, 50, 1.5) spectral = np.asarray([200, 300, 400]) - dataset_model = model.dataset["dataset1"].fill(model, parameters) + dataset_model = fill_item(model.dataset["dataset1"], model, parameters) matrix = MatrixProvider.calculate_dataset_matrix(dataset_model, 0, spectral, time) compartments = matrix.clp_labels @@ -123,8 +121,8 @@ def test_coherent_artifact(spectral_dependence: str): result = optimize(scheme) print(result.optimized_parameters) - for label, param in result.optimized_parameters.all(): - assert np.allclose(param.value, parameters.get(label).value, rtol=1e-8) + for param in result.optimized_parameters.all(): + assert np.allclose(param.value, parameters.get(param.label).value, rtol=1e-1) resultdata = result.data["dataset1"] assert np.array_equal(data.time, resultdata.time) diff --git a/glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py b/glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py index fafb0e430..31f470ea0 100644 --- a/glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py +++ b/glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py @@ -1,48 +1,68 @@ from __future__ import annotations -from typing import List - import numba as nb import numpy as np import xarray as xr from scipy.special import erf -from glotaran.builtin.megacomplexes.decay.irf import Irf +from glotaran.builtin.megacomplexes.decay.decay_parallel_megacomplex import DecayDatasetModel from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian from glotaran.model import DatasetModel +from glotaran.model import ItemIssue from glotaran.model import Megacomplex from glotaran.model import Model +from glotaran.model import ParameterType +from glotaran.model import attribute from glotaran.model import megacomplex -from glotaran.model.item import model_item_validator -from glotaran.parameter import Parameter - - -@megacomplex( - dimension="time", - dataset_model_items={ - "irf": {"type": Irf, "allow_none": True}, - }, - properties={ - "labels": List[str], - "frequencies": List[Parameter], - "rates": List[Parameter], - }, - register_as="damped-oscillation", -) -class DampedOscillationMegacomplex(Megacomplex): - @model_item_validator(False) - def ensure_oscillation_parameter(self, model: Model) -> list[str]: +from glotaran.parameter import Parameters + + +class OscillationParameterIssue(ItemIssue): + def __init__(self, label: str, len_labels: int, len_frequencies: int, len_rates: int): + self._label = label + self._len_labels = len_labels + self._len_frequencies = len_frequencies + self._len_rates = len_rates + + def to_string(self) -> str: + return ( + f"Size of labels ({self.len_labels}), frequencies ({self.len_frequencies}) " + f"and rates 
diff --git a/glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py b/glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py
index fafb0e430..31f470ea0 100644
--- a/glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py
+++ b/glotaran/builtin/megacomplexes/damped_oscillation/damped_oscillation_megacomplex.py
@@ -1,48 +1,68 @@
 from __future__ import annotations

-from typing import List
-
 import numba as nb
 import numpy as np
 import xarray as xr
 from scipy.special import erf

-from glotaran.builtin.megacomplexes.decay.irf import Irf
+from glotaran.builtin.megacomplexes.decay.decay_parallel_megacomplex import DecayDatasetModel
 from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian
 from glotaran.model import DatasetModel
+from glotaran.model import ItemIssue
 from glotaran.model import Megacomplex
 from glotaran.model import Model
+from glotaran.model import ParameterType
+from glotaran.model import attribute
 from glotaran.model import megacomplex
-from glotaran.model.item import model_item_validator
-from glotaran.parameter import Parameter
-
-
-@megacomplex(
-    dimension="time",
-    dataset_model_items={
-        "irf": {"type": Irf, "allow_none": True},
-    },
-    properties={
-        "labels": List[str],
-        "frequencies": List[Parameter],
-        "rates": List[Parameter],
-    },
-    register_as="damped-oscillation",
-)
-class DampedOscillationMegacomplex(Megacomplex):
-    @model_item_validator(False)
-    def ensure_oscillation_parameter(self, model: Model) -> list[str]:
-
-        problems = []
-        if len(self.labels) != len(self.frequencies) or len(self.labels) != len(self.rates):
-            problems.append(
-                f"Size of labels ({len(self.labels)}), frequencies ({len(self.frequencies)}) "
-                f"and rates ({len(self.rates)}) does not match for damped oscillation "
-                f"megacomplex '{self.label}'."
-            )
-
-        return problems
+from glotaran.parameter import Parameters
+
+
+class OscillationParameterIssue(ItemIssue):
+    def __init__(self, label: str, len_labels: int, len_frequencies: int, len_rates: int):
+        self._label = label
+        self._len_labels = len_labels
+        self._len_frequencies = len_frequencies
+        self._len_rates = len_rates
+
+    def to_string(self) -> str:
+        return (
+            f"Size of labels ({self._len_labels}), frequencies ({self._len_frequencies}) "
+            f"and rates ({self._len_rates}) does not match for damped oscillation "
+            f"megacomplex '{self._label}'."
+        )
+
+
+def validate_oscillation_parameter(
+    labels: list[str],
+    damped_oscillation: DampedOscillationMegacomplex,
+    model: Model,
+    parameters: Parameters | None,
+) -> list[ItemIssue]:
+    issues = []
+
+    len_labels, len_frequencies, len_rates = (
+        len(damped_oscillation.labels),
+        len(damped_oscillation.frequencies),
+        len(damped_oscillation.rates),
+    )
+
+    if len({len_labels, len_frequencies, len_rates}) > 1:
+        issues.append(
+            OscillationParameterIssue(
+                damped_oscillation.label, len_labels, len_frequencies, len_rates
+            )
+        )
+
+    return issues
+
+
+@megacomplex(dataset_model_type=DecayDatasetModel)
+class DampedOscillationMegacomplex(Megacomplex):
+    dimension: str = "time"
+    type: str = "damped-oscillation"
+    labels: list[str] = attribute(validator=validate_oscillation_parameter)
+    frequencies: list[ParameterType]
+    rates: list[ParameterType]

     def calculate_matrix(
         self,
diff --git a/glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py b/glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py
index 04b7050df..236045deb 100755
--- a/glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py
+++ b/glotaran/builtin/megacomplexes/damped_oscillation/test/test_doas_model.py
@@ -6,43 +6,23 @@
 from glotaran.builtin.megacomplexes.damped_oscillation import DampedOscillationMegacomplex
 from glotaran.builtin.megacomplexes.decay import DecayMegacomplex
 from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex
-from glotaran.model import Megacomplex
 from glotaran.model import Model
 from glotaran.optimization.optimize import optimize
-from glotaran.parameter import ParameterGroup
+from glotaran.parameter import Parameters
 from glotaran.project import Scheme
 from glotaran.simulation import simulate


-class DampedOscillationsModel(Model):
-    @classmethod
-    def from_dict(
-        cls,
-        model_dict,
-        *,
-        megacomplex_types: dict[str, type[Megacomplex]] | None = None,
-        default_megacomplex_type: str | None = None,
-    ):
-        defaults: dict[str, type[Megacomplex]] = {
-            "damped_oscillation": DampedOscillationMegacomplex,
-            "decay": DecayMegacomplex,
-            "spectral": SpectralMegacomplex,
-        }
-        if megacomplex_types is not None:
-            defaults.update(megacomplex_types)
-        return super().from_dict(
-            model_dict,
-            megacomplex_types=defaults,
-            default_megacomplex_type=default_megacomplex_type,
-        )
+DampedOscillationsModel = Model.create_class_from_megacomplexes(
+    [DampedOscillationMegacomplex, DecayMegacomplex, SpectralMegacomplex]
+)


 class OneOscillation:
-    sim_model = DampedOscillationsModel.from_dict(
-        {
+    sim_model = DampedOscillationsModel(
+        **{
             "megacomplex": {
                 "m1": {
-                    "type": "damped_oscillation",
+                    "type": "damped-oscillation",
                     "labels": ["osc1"],
                     "frequencies": ["osc.freq"],
                     "rates": ["osc.rate"],
@@ -64,11 +44,11 @@ class OneOscillation:
         }
     )

-    model = DampedOscillationsModel.from_dict(
-        {
+    model = DampedOscillationsModel(
+        **{
             "megacomplex": {
                 "m1": {
-                    "type": "damped_oscillation",
+                    "type": "damped-oscillation",
                     "labels": ["osc1"],
                     "frequencies": ["osc.freq"],
                     "rates": ["osc.rate"],
@@ -78,7 +58,7 @@
         }
     )

-    wanted_parameter = ParameterGroup.from_dict(
+    wanted_parameter = Parameters.from_dict(
         {
             "osc": [
                 ["freq", 25.5],
@@ -88,7 +68,7 @@
         }
     )

-    parameter = ParameterGroup.from_dict(
+    parameter = Parameters.from_dict(
         {
             "osc": [
                 ["freq", 20],
@@ -106,11 +86,11 @@


 class OneOscillationWithIrf:
-    sim_model = DampedOscillationsModel.from_dict(
-        {
+    sim_model = DampedOscillationsModel(
+        **{
             "megacomplex": {
                 "m1": {
-                    "type": "damped_oscillation",
+                    "type": "damped-oscillation",
                     "labels": ["osc1"],
                     "frequencies": ["osc.freq"],
                     "rates": ["osc.rate"],
@@ -145,11 +125,11 @@ class OneOscillationWithIrf:
         }
     )

-    model = DampedOscillationsModel.from_dict(
-        {
+    model = DampedOscillationsModel(
+        **{
             "megacomplex": {
                 "m1": {
-                    "type": "damped_oscillation",
+                    "type": "damped-oscillation",
                     "labels": ["osc1"],
                     "frequencies": ["osc.freq"],
                     "rates": ["osc.rate"],
@@ -171,7 +151,7 @@
         }
     )

-    wanted_parameter = ParameterGroup.from_dict(
+    wanted_parameter = Parameters.from_dict(
         {
             "osc": [
                 ["freq", 25],
@@ -182,7 +162,7 @@
         }
     )

-    parameter = ParameterGroup.from_dict(
+    parameter = Parameters.from_dict(
         {
             "osc": [
                 ["freq", 25],
@@ -201,8 +181,8 @@


 class OneOscillationWithSequentialModel:
-    sim_model = DampedOscillationsModel.from_dict(
-        {
+    sim_model = DampedOscillationsModel(
+        **{
             "initial_concentration": {
                 "j1": {"compartments": ["s1", "s2"], "parameters": ["j.1", "j.0"]},
             },
@@ -217,7 +197,7 @@
             "megacomplex": {
                 "m1": {"type": "decay", "k_matrix": ["k1"]},
                 "m2": {
-                    "type": "damped_oscillation",
+                    "type": "damped-oscillation",
                     "labels": ["osc1"],
                     "frequencies": ["osc.freq"],
                     "rates": ["osc.rate"],
@@ -270,8 +250,8 @@
         }
     )

-    model = DampedOscillationsModel.from_dict(
-        {
+    model = DampedOscillationsModel(
+        **{
             "initial_concentration": {
                 "j1": {"compartments": ["s1", "s2"], "parameters": ["j.1", "j.0"]},
             },
@@ -286,7 +266,7 @@
             "megacomplex": {
                 "m1": {"type": "decay", "k_matrix": ["k1"]},
                 "m2": {
-                    "type": "damped_oscillation",
+                    "type": "damped-oscillation",
                     "labels": ["osc1"],
                     "frequencies": ["osc.freq"],
                     "rates": ["osc.rate"],
@@ -309,7 +289,7 @@
         }
     )

-    wanted_parameter = ParameterGroup.from_dict(
+    wanted_parameter = Parameters.from_dict(
         {
             "j": [
                 ["1", 1, {"vary": False, "non-negative": False}],
@@ -328,7 +308,7 @@
         }
     )

-    parameter = ParameterGroup.from_dict(
+    parameter = Parameters.from_dict(
         {
             "j": [
                 ["1", 1, {"vary": False, "non-negative": False}],
@@ -394,8 +374,8 @@ def test_doas_model(suite):
     result = optimize(scheme, raise_exception=True)
     print(result.optimized_parameters)

-    for label, param in result.optimized_parameters.all():
-        assert np.allclose(param.value, suite.wanted_parameter.get(label).value, rtol=1e-1)
+    for param in result.optimized_parameters.all():
+        assert np.allclose(param.value, suite.wanted_parameter.get(param.label).value, rtol=1e-1)

     resultdata = result.data["dataset1"]
     assert np.array_equal(dataset["time"], resultdata["time"])
diff --git a/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py b/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py
index 8e420c33a..1abadc98d 100644
--- a/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py
+++ b/glotaran/builtin/megacomplexes/decay/decay_megacomplex.py
@@ -1,8 +1,5 @@
the decay megacomplex item.""" from __future__ import annotations -from typing import List - import numpy as np import xarray as xr @@ -15,23 +12,22 @@ from glotaran.model import DatasetModel from glotaran.model import Megacomplex from glotaran.model import ModelError +from glotaran.model import ModelItemType +from glotaran.model import item from glotaran.model import megacomplex -@megacomplex( - dimension="time", - model_items={ - "k_matrix": List[KMatrix], - }, - properties={}, - dataset_model_items={ - "initial_concentration": {"type": InitialConcentration, "allow_none": True}, - "irf": {"type": Irf, "allow_none": True}, - }, - register_as="decay", -) +@item +class DecayDatasetModel(DatasetModel): + initial_concentration: ModelItemType[InitialConcentration] | None = None + irf: ModelItemType[Irf] | None = None + + +@megacomplex(dataset_model_type=DecayDatasetModel) class DecayMegacomplex(Megacomplex): - """A Megacomplex with one or more K-Matrices.""" + dimension: str = "time" + type: str = "decay" + k_matrix: list[ModelItemType[KMatrix]] def get_compartments(self, dataset_model: DatasetModel) -> list[str]: if dataset_model.initial_concentration is None: diff --git a/glotaran/builtin/megacomplexes/decay/decay_parallel_megacomplex.py b/glotaran/builtin/megacomplexes/decay/decay_parallel_megacomplex.py index 85e202afe..fddfde848 100644 --- a/glotaran/builtin/megacomplexes/decay/decay_parallel_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/decay_parallel_megacomplex.py @@ -1,8 +1,6 @@ """This package contains the decay megacomplex item.""" from __future__ import annotations -from typing import List - import numpy as np import xarray as xr @@ -13,22 +11,24 @@ from glotaran.builtin.megacomplexes.decay.util import index_dependent from glotaran.model import DatasetModel from glotaran.model import Megacomplex +from glotaran.model import ModelItemType +from glotaran.model import ParameterType +from glotaran.model import item from glotaran.model import megacomplex -from glotaran.parameter import Parameter -@megacomplex( - dimension="time", - properties={ - "compartments": List[str], - "rates": List[Parameter], - }, - dataset_model_items={ - "irf": {"type": Irf, "allow_none": True}, - }, - register_as="decay-parallel", -) +@item +class DecayDatasetModel(DatasetModel): + irf: ModelItemType[Irf] | None = None + + +@megacomplex(dataset_model_type=DecayDatasetModel) class DecayParallelMegacomplex(Megacomplex): + dimension: str = "time" + type: str = "decay-parallel" + compartments: list[str] + rates: list[ParameterType] + def get_compartments(self, dataset_model: DatasetModel) -> list[str]: return self.compartments @@ -41,12 +41,13 @@ def get_initial_concentration( return initial_concentration def get_k_matrix(self) -> KMatrix: - size = len(self.compartments) - k_matrix = KMatrix() - k_matrix.matrix = { - (self.compartments[i], self.compartments[i]): self.rates[i] for i in range(size) - } - return k_matrix + return KMatrix( + label="", + matrix={ + (self.compartments[i], self.compartments[i]): self.rates[i] + for i in range(len(self.compartments)) + }, + ) def get_a_matrix(self, dataset_model: DatasetModel) -> np.ndarray: return self.get_k_matrix().a_matrix_general( @@ -58,7 +59,7 @@ def index_dependent(self, dataset_model: DatasetModel) -> bool: def calculate_matrix( self, - dataset_model: DatasetModel, + dataset_model: DecayDatasetModel, global_index: int | None, global_axis: np.typing.ArrayLike, model_axis: np.typing.ArrayLike, diff --git 
a/glotaran/builtin/megacomplexes/decay/decay_sequential_megacomplex.py b/glotaran/builtin/megacomplexes/decay/decay_sequential_megacomplex.py index 423d02471..7534dd3ba 100644 --- a/glotaran/builtin/megacomplexes/decay/decay_sequential_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/decay_sequential_megacomplex.py @@ -1,36 +1,25 @@ """This package contains the decay megacomplex item.""" from __future__ import annotations -from typing import List - import numpy as np import xarray as xr -from glotaran.builtin.megacomplexes.decay.irf import Irf +from glotaran.builtin.megacomplexes.decay import DecayParallelMegacomplex +from glotaran.builtin.megacomplexes.decay.decay_parallel_megacomplex import DecayDatasetModel from glotaran.builtin.megacomplexes.decay.k_matrix import KMatrix from glotaran.builtin.megacomplexes.decay.util import calculate_matrix from glotaran.builtin.megacomplexes.decay.util import finalize_data from glotaran.builtin.megacomplexes.decay.util import index_dependent from glotaran.model import DatasetModel -from glotaran.model import Megacomplex from glotaran.model import megacomplex -from glotaran.parameter import Parameter -@megacomplex( - dimension="time", - properties={ - "compartments": List[str], - "rates": List[Parameter], - }, - dataset_model_items={ - "irf": {"type": Irf, "allow_none": True}, - }, - register_as="decay-sequential", -) -class DecaySequentialMegacomplex(Megacomplex): +@megacomplex(dataset_model_type=DecayDatasetModel) +class DecaySequentialMegacomplex(DecayParallelMegacomplex): """A Megacomplex with one or more K-Matrices.""" + type: str = "decay-sequential" + def get_compartments(self, dataset_model: DatasetModel) -> list[str]: return self.compartments @@ -42,12 +31,13 @@ def get_initial_concentration( return initial_concentration def get_k_matrix(self) -> KMatrix: - size = len(self.compartments) - k_matrix = KMatrix() - k_matrix.matrix = { - (self.compartments[i + 1], self.compartments[i]): self.rates[i] - for i in range(size - 1) - } + k_matrix = KMatrix( + label="", + matrix={ + (self.compartments[i + 1], self.compartments[i]): self.rates[i] + for i in range(len(self.compartments) - 1) + }, + ) k_matrix.matrix[self.compartments[-1], self.compartments[-1]] = self.rates[-1] return k_matrix diff --git a/glotaran/builtin/megacomplexes/decay/initial_concentration.py b/glotaran/builtin/megacomplexes/decay/initial_concentration.py index d6282f062..40ec5148e 100644 --- a/glotaran/builtin/megacomplexes/decay/initial_concentration.py +++ b/glotaran/builtin/megacomplexes/decay/initial_concentration.py @@ -1,25 +1,22 @@ """This package contains the initial concentration item.""" from __future__ import annotations -from typing import List - import numpy as np -from glotaran.model import model_item -from glotaran.parameter import Parameter +from glotaran.model import ModelItem +from glotaran.model import ParameterType +from glotaran.model import item -@model_item( - properties={ - "compartments": List[str], - "parameters": List[Parameter], - "exclude_from_normalize": {"type": List[str], "default": []}, - } -) -class InitialConcentration: +@item +class InitialConcentration(ModelItem): """An initial concentration describes the population of the compartments at the beginning of an experiment.""" + compartments: list[str] + parameters: list[ParameterType] + exclude_from_normalize: list[str] = [] + def normalized(self) -> np.ndarray: normalized = np.array(self.parameters) idx = [c not in self.exclude_from_normalize for c in self.compartments] diff --git 
a/glotaran/builtin/megacomplexes/decay/irf.py b/glotaran/builtin/megacomplexes/decay/irf.py index 9e0aff7ff..10107992e 100644 --- a/glotaran/builtin/megacomplexes/decay/irf.py +++ b/glotaran/builtin/megacomplexes/decay/irf.py @@ -1,34 +1,22 @@ """This package contains irf items.""" -from typing import List -from typing import Tuple import numpy as np from glotaran.model import ModelError -from glotaran.model import model_item -from glotaran.model import model_item_typed -from glotaran.parameter import Parameter - - -@model_item(has_type=True) -class IrfMeasured: - """A measured IRF. The data must be supplied by the dataset.""" - - -@model_item( - properties={ - "center": List[Parameter], - "width": List[Parameter], - "scale": {"type": List[Parameter], "allow_none": True}, - "shift": {"type": List[Parameter], "allow_none": True}, - "normalize": {"type": bool, "default": True}, - "backsweep": {"type": bool, "default": False}, - "backsweep_period": {"type": Parameter, "allow_none": True}, - }, - has_type=True, -) -class IrfMultiGaussian: +from glotaran.model import ModelItemTyped +from glotaran.model import ParameterType +from glotaran.model import attribute +from glotaran.model import item + + +@item +class Irf(ModelItemTyped): + """Represents an IRF.""" + + +@item +class IrfMultiGaussian(Irf): """ Represents a gaussian IRF. @@ -56,9 +44,19 @@ class IrfMultiGaussian: """ + type: str = "multi-gaussian" + + center: list[ParameterType] + width: list[ParameterType] + scale: list[ParameterType] | None = None + shift: list[ParameterType] | None = None + normalize: bool = True + backsweep: bool = False + backsweep_period: ParameterType | None = None + def parameter( self, global_index: int, global_axis: np.ndarray - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, bool, float]: + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, float, bool, float]: """Returns the properties of the irf with shift applied.""" centers = self.center if isinstance(self.center, list) else [self.center] @@ -110,26 +108,14 @@ def is_index_dependent(self): return self.shift is not None -@model_item( - properties={ - "center": Parameter, - "width": Parameter, - }, - has_type=True, -) +@item class IrfGaussian(IrfMultiGaussian): - pass - - -@model_item( - properties={ - "dispersion_center": {"type": Parameter, "allow_none": True}, - "center_dispersion_coefficients": {"type": List[Parameter], "default": []}, - "width_dispersion_coefficients": {"type": List[Parameter], "default": []}, - "model_dispersion_with_wavenumber": {"type": bool, "default": False}, - }, - has_type=True, -) + type: str = "gaussian" + center: ParameterType + width: ParameterType + + +@item class IrfSpectralMultiGaussian(IrfMultiGaussian): """ Represents a gaussian IRF. 
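
The hunks above replace the old `@model_item(properties={...})` dictionaries with plain attrs-style classes: `Irf` is now the `ModelItemTyped` base, and each concrete IRF declares its fields and a `type` discriminator directly. A minimal usage sketch follows, assuming an installed post-migration glotaran; the `__init__` is generated by the `@item` decorator, so the keyword-argument form (including `label`) shown here is an assumption rather than something this diff confirms:

```python
# Hypothetical sketch of declaring an IRF with the attrs-based item API.
from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian

irf = IrfMultiGaussian(
    label="irf1",           # assumed: model items still carry a label
    center=["irf.center"],  # parameter labels, resolved later via fill_item
    width=["irf.width"],
)

# Defaults now come from class attributes instead of a properties dict:
assert irf.normalize is True
assert irf.backsweep is False
assert irf.type == "multi-gaussian"
```
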
@@ -158,6 +144,12 @@ class IrfSpectralMultiGaussian(IrfMultiGaussian): """ + type: str = "spectral-multi-gaussian" + dispersion_center: ParameterType + center_dispersion_coefficients: list[ParameterType] + width_dispersion_coefficients: list[ParameterType] = attribute(factory=list) + model_dispersion_with_wavenumber: bool = False + def parameter(self, global_index: int, global_axis: np.ndarray): """Returns the properties of the irf with shift and dispersion applied.""" centers, widths, scale, shift, backsweep, backsweep_period = super().parameter( @@ -198,24 +190,8 @@ def is_index_dependent(self): return super().is_index_dependent() or self.dispersion_center is not None -@model_item( - properties={ - "center": Parameter, - "width": Parameter, - }, - has_type=True, -) +@item class IrfSpectralGaussian(IrfSpectralMultiGaussian): - pass - - -@model_item_typed( - types={ - "gaussian": IrfGaussian, - "multi-gaussian": IrfMultiGaussian, - "spectral-multi-gaussian": IrfSpectralMultiGaussian, - "spectral-gaussian": IrfSpectralGaussian, - } -) -class Irf: - """Represents an IRF.""" + type: str = "spectral-gaussian" + center: ParameterType + width: ParameterType diff --git a/glotaran/builtin/megacomplexes/decay/k_matrix.py b/glotaran/builtin/megacomplexes/decay/k_matrix.py index b06f05b1e..98882745a 100644 --- a/glotaran/builtin/megacomplexes/decay/k_matrix.py +++ b/glotaran/builtin/megacomplexes/decay/k_matrix.py @@ -2,15 +2,15 @@ from __future__ import annotations import itertools -import typing from collections import OrderedDict import numpy as np from scipy.linalg import eig from scipy.linalg import solve -from glotaran.model import model_item -from glotaran.parameter import Parameter +from glotaran.model import ModelItem +from glotaran.model import ParameterType +from glotaran.model import item from glotaran.utils.ipython import MarkdownStr @@ -18,14 +18,12 @@ def calculate_gamma(eigenvectors: np.ndarray, initial_concentration: np.ndarray) return np.diag(solve(eigenvectors, initial_concentration)) -@model_item( - properties={ - "matrix": {"type": typing.Dict[typing.Tuple[str, str], Parameter]}, - }, -) -class KMatrix: +@item +class KMatrix(ModelItem): """A K-Matrix represents a first order differental system.""" + matrix: dict[tuple[str, str], ParameterType] + @classmethod def empty(cls, label: str, compartments: list[str]) -> KMatrix: """Creates an empty K-Matrix. Useful for combining. 
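
`KMatrix` is now an `@item` class as well, so instances are built in a single constructor call instead of being mutated field by field, with parameter references kept as plain labels until they are resolved. The sketch below mirrors the rewritten k-matrix tests further down in this diff, which use exactly this `KMatrix(label=..., matrix=...)` plus `fill_item(mat, None, params)` pattern; the compartment names and values here are illustrative:

```python
# Construction sketch following the rewritten k-matrix tests in this diff.
from glotaran.builtin.megacomplexes.decay.k_matrix import KMatrix
from glotaran.model.item import fill_item
from glotaran.parameter import Parameters

# Two-compartment scheme; matrix values are parameter labels at this point.
mat = KMatrix(label="k1", matrix={("s2", "s1"): "1", ("s2", "s2"): "2"})

# Resolve the labels against real values; the tests below pass model=None
# for free-standing items like this one.
params = Parameters.from_list([0.55, 0.0404])
filled = fill_item(mat, None, params)
```
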
@@ -76,10 +74,7 @@ def combine(self, k_matrix: KMatrix) -> KMatrix: combined_matrix = {entry: self.matrix[entry] for entry in self.matrix} for entry in k_matrix.matrix: combined_matrix[entry] = k_matrix.matrix[entry] - combined = KMatrix() - combined.label = f"{self.label}+{k_matrix.label}" - combined.matrix = combined_matrix - return combined + return KMatrix(label=f"{self.label}+{k_matrix.label}", matrix=combined_matrix) def matrix_as_markdown( self, @@ -105,9 +100,7 @@ def matrix_as_markdown( for index in self.matrix: i = compartments.index(index[0]) j = compartments.index(index[1]) - array[i, j] = ( - self.matrix[index].value if fill_parameters else self.matrix[index].full_label - ) + array[i, j] = self.matrix[index].value if fill_parameters else self.matrix[index] return self._array_as_markdown(array, compartments, compartments) diff --git a/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py index 702b4a505..3251d7763 100644 --- a/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_decay_megacomplex.py @@ -4,9 +4,13 @@ import pytest import xarray as xr +from glotaran.builtin.megacomplexes.decay import DecayMegacomplex +from glotaran.builtin.megacomplexes.decay import DecayParallelMegacomplex +from glotaran.builtin.megacomplexes.decay import DecaySequentialMegacomplex from glotaran.model import Model +from glotaran.model.item import fill_item from glotaran.optimization.optimize import optimize -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.project import Scheme from glotaran.simulation import simulate @@ -21,24 +25,20 @@ def create_gaussian_clp(labels, amplitudes, centers, widths, axis): ).T -class DecayModel(Model): - @classmethod - def from_dict( - cls, - model_dict, - ): - model_dict = {**model_dict, "default_megacomplex": "decay"} - return super().from_dict(model_dict) +DecaySimpleModel = Model.create_class_from_megacomplexes( + [DecayParallelMegacomplex, DecaySequentialMegacomplex] +) +DecayModel = Model.create_class_from_megacomplexes([DecayMegacomplex]) class OneComponentOneChannel: - model = DecayModel.from_dict( - { + model = DecayModel( + **{ "initial_concentration": { "j1": {"compartments": ["s1"], "parameters": ["2"]}, }, "megacomplex": { - "mc1": {"k_matrix": ["k1"]}, + "mc1": {"type": "decay", "k_matrix": ["k1"]}, }, "k_matrix": { "k1": { @@ -56,12 +56,10 @@ class OneComponentOneChannel: } ) - initial_parameters = ParameterGroup.from_list( + initial_parameters = Parameters.from_list( [101e-4, [1, {"vary": False, "non-negative": False}]] ) - wanted_parameters = ParameterGroup.from_list( - [101e-3, [1, {"vary": False, "non-negative": False}]] - ) + wanted_parameters = Parameters.from_list([101e-3, [1, {"vary": False, "non-negative": False}]]) time = np.arange(0, 50, 1.5) pixel = np.asarray([0]) @@ -71,13 +69,13 @@ class OneComponentOneChannel: class OneComponentOneChannelGaussianIrf: - model = DecayModel.from_dict( - { + model = DecayModel( + **{ "initial_concentration": { "j1": {"compartments": ["s1"], "parameters": ["5"]}, }, "megacomplex": { - "mc1": {"k_matrix": ["k1"]}, + "mc1": {"type": "decay", "k_matrix": ["k1"]}, }, "k_matrix": { "k1": { @@ -99,13 +97,20 @@ class OneComponentOneChannelGaussianIrf: } ) - initial_parameters = ParameterGroup.from_list( - [101e-4, 0.1, 1, [0.1, {"vary": False}], [1, {"vary": False, "non-negative": False}]] + initial_parameters 
= Parameters.from_list( + [ + 101e-4, + 0.1, + 1, + [0.1, {"vary": False}], + [1, {"vary": False, "non-negative": False}], + ] ) + print(initial_parameters) assert model.megacomplex["mc1"].index_dependent( - model.dataset["dataset1"].fill(model, initial_parameters) + fill_item(model.dataset["dataset1"], model, initial_parameters) ) - wanted_parameters = ParameterGroup.from_list( + wanted_parameters = Parameters.from_list( [ [101e-3, {"non-negative": True}], [0.2, {"non-negative": True}], @@ -123,8 +128,8 @@ class OneComponentOneChannelGaussianIrf: class ThreeComponentParallel: - model = DecayModel.from_dict( - { + model = DecaySimpleModel( + **{ "megacomplex": { "mc1": { "type": "decay-parallel", @@ -152,7 +157,7 @@ class ThreeComponentParallel: } ) - initial_parameters = ParameterGroup.from_dict( + initial_parameters = Parameters.from_dict( { "kinetic": [ ["1", 501e-3], @@ -163,7 +168,7 @@ class ThreeComponentParallel: "irf": [["center", 1.3], ["width", 7.8]], } ) - wanted_parameters = ParameterGroup.from_dict( + wanted_parameters = Parameters.from_dict( { "kinetic": [ ["1", 501e-3], @@ -183,8 +188,8 @@ class ThreeComponentParallel: class ThreeComponentSequential: - model = DecayModel.from_dict( - { + model = DecaySimpleModel( + **{ "megacomplex": { "mc1": { "type": "decay-sequential", @@ -212,7 +217,7 @@ class ThreeComponentSequential: } ) - initial_parameters = ParameterGroup.from_dict( + initial_parameters = Parameters.from_dict( { "kinetic": [ ["1", 501e-3], @@ -223,7 +228,7 @@ class ThreeComponentSequential: "irf": [["center", 1.3], ["width", 7.8]], } ) - wanted_parameters = ParameterGroup.from_dict( + wanted_parameters = Parameters.from_dict( { "kinetic": [ ["1", 501e-3], @@ -256,7 +261,7 @@ def test_kinetic_model(suite, nnls): model = suite.model print(model.validate()) assert model.valid() - model.dataset_group_models["default"].method = ( + model.dataset_groups["default"].method = ( "non_negative_least_squares" if nnls else "variable_projection" ) @@ -286,8 +291,8 @@ def test_kinetic_model(suite, nnls): result = optimize(scheme) print(result.optimized_parameters) - for label, param in result.optimized_parameters.all(): - assert np.allclose(param.value, wanted_parameters.get(label).value, rtol=1e-1) + for param in result.optimized_parameters.all(): + assert np.allclose(param.value, wanted_parameters.get(param.label).value, rtol=1e-1) resultdata = result.data["dataset1"] assert np.array_equal(dataset["time"], resultdata["time"]) @@ -307,14 +312,14 @@ def test_kinetic_model(suite, nnls): def test_finalize_data(): - model = DecayModel.from_dict( - { + model = DecayModel( + **{ "initial_concentration": { "j1": {"compartments": ["s1", "s2"], "parameters": ["3", "3"]}, }, "megacomplex": { - "mc1": {"k_matrix": ["k1"]}, - "mc2": {"k_matrix": ["k2"]}, + "mc1": {"type": "decay", "k_matrix": ["k1"]}, + "mc2": {"type": "decay", "k_matrix": ["k2"]}, }, "k_matrix": { "k1": { @@ -337,7 +342,7 @@ def test_finalize_data(): } ) - parameters = ParameterGroup.from_list( + parameters = Parameters.from_list( [101e-4, 101e-3, [1, {"vary": False, "non-negative": False}]] ) diff --git a/glotaran/builtin/megacomplexes/decay/test/test_k_matrix.py b/glotaran/builtin/megacomplexes/decay/test/test_k_matrix.py index 48a5d3581..9d2ede48c 100644 --- a/glotaran/builtin/megacomplexes/decay/test/test_k_matrix.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_k_matrix.py @@ -4,7 +4,8 @@ from glotaran.builtin.megacomplexes.decay.k_matrix import KMatrix from glotaran.builtin.megacomplexes.decay.k_matrix import 
calculate_gamma -from glotaran.parameter import ParameterGroup +from glotaran.model.item import fill_item +from glotaran.parameter import Parameters class SequentialModel: @@ -184,12 +185,10 @@ class ParallelModelWithEquilibria: ) def test_a_matrix_general(matrix): - params = ParameterGroup.from_list(matrix.params) + params = Parameters.from_list(matrix.params) - mat = KMatrix() - mat.label = "" - mat.matrix = matrix.matrix - mat = mat.fill(None, params) + mat = KMatrix(label="", matrix=matrix.matrix) + mat = fill_item(mat, None, params) initial_concentration = matrix.jvec @@ -228,11 +227,9 @@ def test_a_matrix_sequential(): ("s3", "s3"): "3", } - params = ParameterGroup.from_list([3, 4, 5]) - mat = KMatrix() - mat.label = "" - mat.matrix = matrix - mat = mat.fill(None, params) + params = Parameters.from_list([3, 4, 5]) + mat = KMatrix(label="", matrix=matrix) + mat = fill_item(mat, None, params) initial_concentration = [1, 0, 0] @@ -244,11 +241,9 @@ def test_a_matrix_sequential(): } compartments = ["s1", "s2"] - params = ParameterGroup.from_list([0.55, 0.0404]) - mat = KMatrix() - mat.label = "" - mat.matrix = matrix - mat = mat.fill(None, params) + params = Parameters.from_list([0.55, 0.0404]) + mat = KMatrix(label="", matrix=matrix) + mat = fill_item(mat, None, params) initial_concentration = [1, 0] @@ -272,24 +267,19 @@ def test_combine_matrices(): ("s1", "s1"): "1", ("s2", "s2"): "2", } - mat1 = KMatrix() - mat1.label = "A" - mat1.matrix = matrix1 - + mat1 = KMatrix(label="A", matrix=matrix1) matrix2 = { ("s2", "s2"): "3", ("s3", "s3"): "4", } - mat2 = KMatrix() - mat2.label = "B" - mat2.matrix = matrix2 + mat2 = KMatrix(label="B", matrix=matrix2) combined = mat1.combine(mat2) assert combined.label == "A+B" - assert combined.matrix[("s1", "s1")].full_label == "1" - assert combined.matrix[("s2", "s2")].full_label == "3" - assert combined.matrix[("s3", "s3")].full_label == "4" + assert combined.matrix[("s1", "s1")] == "1" + assert combined.matrix[("s2", "s2")] == "3" + assert combined.matrix[("s3", "s3")] == "4" def test_kmatrix_ipython_rendering(): @@ -299,9 +289,7 @@ def test_kmatrix_ipython_rendering(): ("s1", "s1"): "1", ("s2", "s2"): "2", } - kmatrix = KMatrix() - kmatrix.label = "A" - kmatrix.matrix = matrix + kmatrix = KMatrix(label="A", matrix=matrix) rendered_obj = format_display_data(kmatrix)[0] diff --git a/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py b/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py index 913631fe8..324d14bb0 100644 --- a/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py +++ b/glotaran/builtin/megacomplexes/decay/test/test_spectral_irf.py @@ -1,12 +1,13 @@ import warnings -from copy import deepcopy from textwrap import dedent import numpy as np import pytest +from attrs import evolve from glotaran.io import load_model from glotaran.io import load_parameters +from glotaran.model import fill_item from glotaran.optimization.optimize import optimize from glotaran.project import Scheme from glotaran.simulation import simulate @@ -41,7 +42,7 @@ {MODEL_BASE} irf: irf1: - type: spectral-gaussian + type: gaussian center: irf.center width: irf.width """ @@ -190,8 +191,9 @@ def test_spectral_irf(suite): parameters = suite.parameters assert model.valid(parameters), model.validate(parameters) - sim_model = deepcopy(model) + sim_model = evolve(model) sim_model.dataset["dataset1"].global_megacomplex = ["mc2"] + print(sim_model) dataset = simulate(sim_model, "dataset1", parameters, suite.axis) assert dataset.data.shape == 
(suite.axis["time"].size, suite.axis["spectral"].size) @@ -206,13 +208,8 @@ def test_spectral_irf(suite): ) result = optimize(scheme) - for label, param in result.optimized_parameters.all(): - assert np.allclose(param.value, parameters.get(label).value), dedent( - f""" - Error in {suite.__name__} comparing {param.full_label}, - - diff={param.value-parameters.get(label).value} - """ - ) + for param in result.optimized_parameters.all(): + assert np.allclose(param.value, parameters.get(param.label).value, rtol=1e-1) resultdata = result.data["dataset1"] @@ -246,11 +243,10 @@ def test_spectral_irf(suite): for x in suite.axis["spectral"]: # calculated irf location - model_irf_center = suite.model.irf["irf1"].center - model_dispersion_center = suite.model.irf["irf1"].dispersion_center - model_center_dispersion_coefficients = suite.model.irf[ - "irf1" - ].center_dispersion_coefficients + irf = fill_item(suite.model.irf["irf1"], suite.model, result.optimized_parameters) + model_irf_center = irf.center + model_dispersion_center = irf.dispersion_center + model_center_dispersion_coefficients = irf.center_dispersion_coefficients calc_irf_location_at_x = _calculate_irf_position( x, model_irf_center, model_dispersion_center, model_center_dispersion_coefficients ) diff --git a/glotaran/builtin/megacomplexes/decay/util.py b/glotaran/builtin/megacomplexes/decay/util.py index 16e028478..809e5dc22 100644 --- a/glotaran/builtin/megacomplexes/decay/util.py +++ b/glotaran/builtin/megacomplexes/decay/util.py @@ -8,6 +8,7 @@ from glotaran.builtin.megacomplexes.decay.irf import IrfSpectralMultiGaussian from glotaran.model import DatasetModel from glotaran.model import Megacomplex +from glotaran.model import get_dataset_model_model_dimension def index_dependent(dataset_model: DatasetModel) -> bool: @@ -233,7 +234,7 @@ def retrieve_species_associated_data( is_full_model: bool, as_global: bool, ): - model_dimension = dataset_model.get_model_dimension() + model_dimension = get_dataset_model_model_dimension(dataset_model) if as_global: model_dimension, global_dimension = global_dimension, model_dimension dataset.coords[species_dimension] = species @@ -355,7 +356,7 @@ def retrieve_irf(dataset_model: DatasetModel, dataset: xr.Dataset, global_dimens return irf = dataset_model.irf - model_dimension = dataset_model.get_model_dimension() + model_dimension = get_dataset_model_model_dimension(dataset_model) dataset["irf"] = ( (model_dimension), diff --git a/glotaran/builtin/megacomplexes/spectral/shape.py b/glotaran/builtin/megacomplexes/spectral/shape.py index f727b9ea6..5ede00bab 100644 --- a/glotaran/builtin/megacomplexes/spectral/shape.py +++ b/glotaran/builtin/megacomplexes/spectral/shape.py @@ -2,22 +2,25 @@ import numpy as np -from glotaran.model import model_item -from glotaran.model import model_item_typed -from glotaran.parameter import Parameter - - -@model_item( - properties={ - "amplitude": {"type": Parameter, "allow_none": True}, - "location": Parameter, - "width": Parameter, - }, - has_type=True, -) -class SpectralShapeGaussian: +from glotaran.model import ModelItemTyped +from glotaran.model import ParameterType +from glotaran.model import item + + +@item +class SpectralShape(ModelItemTyped): + pass + + +@item +class SpectralShapeGaussian(SpectralShape): """A Gaussian spectral shape""" + type: str = "gaussian" + amplitude: ParameterType | None = None + location: ParameterType + width: ParameterType + def calculate(self, axis: np.ndarray) -> np.ndarray: r"""Calculate a normal Gaussian shape for a given ``axis``. 
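
`SpectralShapeGaussian.calculate`, shown right above, evaluates a Gaussian band over the spectral axis from its `amplitude`, `location` and `width` fields. As a plain-numpy illustration of the kind of shape involved (this is not the library code, and the exact in-library formula is not part of this diff):

```python
import numpy as np

def gaussian_shape(axis: np.ndarray, amplitude: float, location: float, width: float) -> np.ndarray:
    # Gaussian band peaking at `location`, with `width` acting as the FWHM.
    return amplitude * np.exp(-np.log(2) * np.square(2 * (axis - location) / width))

axis = np.arange(400, 600, 5)  # spectral axis, as in the tests in this diff
shape = gaussian_shape(axis, amplitude=7.0, location=450.0, width=80.0)
assert axis[shape.argmax()] == 450  # the band peaks at `location`
```
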
@@ -61,15 +64,13 @@ def calculate(self, axis: np.ndarray) -> np.ndarray: return shape -@model_item( - properties={ - "skewness": Parameter, - }, - has_type=True, -) +@item class SpectralShapeSkewedGaussian(SpectralShapeGaussian): """A skewed Gaussian spectral shape""" + type: str = "skewed-gaussian" + skewness: ParameterType + def calculate(self, axis: np.ndarray) -> np.ndarray: r"""Calculate the skewed Gaussian shape for ``axis``. @@ -135,10 +136,12 @@ def calculate(self, axis: np.ndarray) -> np.ndarray: return shape -@model_item(properties={}, has_type=True) -class SpectralShapeOne: +@item +class SpectralShapeOne(SpectralShape): """A constant spectral shape with value 1""" + type: str = "one" + def calculate(self, axis: np.ndarray) -> np.ndarray: """calculate calculates the shape. @@ -155,10 +158,12 @@ def calculate(self, axis: np.ndarray) -> np.ndarray: return np.ones(axis.shape[0]) -@model_item(properties={}, has_type=True) -class SpectralShapeZero: +@item +class SpectralShapeZero(SpectralShape): """A constant spectral shape with value 0""" + type: str = "zero" + def calculate(self, axis: np.ndarray) -> np.ndarray: """calculate calculates the shape. @@ -175,15 +180,3 @@ def calculate(self, axis: np.ndarray) -> np.ndarray: """ return np.zeros(axis.shape[0]) - - -@model_item_typed( - types={ - "gaussian": SpectralShapeGaussian, - "skewed-gaussian": SpectralShapeSkewedGaussian, - "one": SpectralShapeOne, - "zero": SpectralShapeZero, - } -) -class SpectralShape: - """Base class for spectral shapes""" diff --git a/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py b/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py index e5d6de9e7..754b589b0 100644 --- a/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py +++ b/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Dict - import numpy as np import xarray as xr @@ -9,21 +7,23 @@ from glotaran.model import DatasetModel from glotaran.model import Megacomplex from glotaran.model import ModelError +from glotaran.model import ModelItemType +from glotaran.model import item from glotaran.model import megacomplex -@megacomplex( - dimension="spectral", - dataset_properties={ - "spectral_axis_inverted": {"type": bool, "default": False}, - "spectral_axis_scale": {"type": float, "default": 1}, - }, - model_items={ - "shape": Dict[str, SpectralShape], - }, - register_as="spectral", -) +@item +class SpectralDatasetModel(DatasetModel): + spectral_axis_inverted: bool = False + spectral_axis_scale: float = 1 + + +@megacomplex(dataset_model_type=SpectralDatasetModel) class SpectralMegacomplex(Megacomplex): + dimension: str = "spectral" + type: str = "spectral" + shape: dict[str, ModelItemType[SpectralShape]] + def calculate_matrix( self, dataset_model: DatasetModel, diff --git a/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py index 2871732db..18067a0a2 100644 --- a/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py +++ b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py @@ -6,44 +6,25 @@ from glotaran.builtin.megacomplexes.decay.test.test_decay_megacomplex import DecayModel from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex -from glotaran.model import Megacomplex from glotaran.model import Model +from glotaran.model import fill_item from glotaran.optimization.matrix_provider import MatrixProvider 
from glotaran.optimization.optimize import optimize -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.project import Scheme from glotaran.simulation import simulate - -class SpectralModel(Model): - @classmethod - def from_dict( - cls, - model_dict, - *, - megacomplex_types: dict[str, type[Megacomplex]] | None = None, - default_megacomplex_type: str | None = None, - ): - defaults: dict[str, type[Megacomplex]] = { - "spectral": SpectralMegacomplex, - } - if megacomplex_types is not None: - defaults.update(megacomplex_types) - return super().from_dict( - model_dict, - megacomplex_types=defaults, - default_megacomplex_type=default_megacomplex_type, - ) +SpectralModel = Model.create_class_from_megacomplexes([SpectralMegacomplex]) class OneCompartmentModelInvertedAxis: - decay_model = DecayModel.from_dict( - { + decay_model = DecayModel( + **{ "initial_concentration": { "j1": {"compartments": ["s1"], "parameters": ["2"]}, }, "megacomplex": { - "mc1": {"k_matrix": ["k1"]}, + "mc1": {"type": "decay", "k_matrix": ["k1"]}, }, "k_matrix": { "k1": { @@ -61,14 +42,12 @@ class OneCompartmentModelInvertedAxis: } ) - decay_parameters = ParameterGroup.from_list( - [101e-4, [1, {"vary": False, "non-negative": False}]] - ) + decay_parameters = Parameters.from_list([101e-4, [1, {"vary": False, "non-negative": False}]]) - spectral_model = SpectralModel.from_dict( - { + spectral_model = SpectralModel( + **{ "megacomplex": { - "mc1": {"shape": {"s1": "sh1"}}, + "mc1": {"type": "spectral", "shape": {"s1": "sh1"}}, }, "shape": { "sh1": { @@ -88,26 +67,26 @@ class OneCompartmentModelInvertedAxis: } ) - spectral_parameters = ParameterGroup.from_list([7, 1e7 / 10000, 800, -1]) + spectral_parameters = Parameters.from_list([7, 1e7 / 10000, 800, -1]) time = np.arange(-10, 50, 1.5) spectral = np.arange(5000, 15000, 20) axis = {"time": time, "spectral": spectral} - decay_dataset_model = decay_model.dataset["dataset1"].fill(decay_model, decay_parameters) + decay_dataset_model = fill_item(decay_model.dataset["dataset1"], decay_model, decay_parameters) matrix = MatrixProvider.calculate_dataset_matrix(decay_dataset_model, None, spectral, time) decay_compartments = matrix.clp_labels clp = xr.DataArray(matrix.matrix, coords=[("time", time), ("clp_label", decay_compartments)]) class OneCompartmentModelNegativeSkew: - decay_model = DecayModel.from_dict( - { + decay_model = DecayModel( + **{ "initial_concentration": { "j1": {"compartments": ["s1"], "parameters": ["2"]}, }, "megacomplex": { - "mc1": {"k_matrix": ["k1"]}, + "mc1": {"type": "decay", "k_matrix": ["k1"]}, }, "k_matrix": { "k1": { @@ -125,14 +104,12 @@ class OneCompartmentModelNegativeSkew: } ) - decay_parameters = ParameterGroup.from_list( - [101e-4, [1, {"vary": False, "non-negative": False}]] - ) + decay_parameters = Parameters.from_list([101e-4, [1, {"vary": False, "non-negative": False}]]) - spectral_model = SpectralModel.from_dict( - { + spectral_model = SpectralModel( + **{ "megacomplex": { - "mc1": {"shape": {"s1": "sh1"}}, + "mc1": {"type": "spectral", "shape": {"s1": "sh1"}}, }, "shape": { "sh1": { @@ -148,34 +125,34 @@ class OneCompartmentModelNegativeSkew: } ) - spectral_parameters = ParameterGroup.from_list([1000, 80, -1]) + spectral_parameters = Parameters.from_list([1000, 80, -1]) time = np.arange(-10, 50, 1.5) spectral = np.arange(400, 600, 5) axis = {"time": time, "spectral": spectral} - decay_dataset_model = decay_model.dataset["dataset1"].fill(decay_model, decay_parameters) + decay_dataset_model = 
fill_item(decay_model.dataset["dataset1"], decay_model, decay_parameters) matrix = MatrixProvider.calculate_dataset_matrix(decay_dataset_model, None, spectral, time) decay_compartments = matrix.clp_labels clp = xr.DataArray(matrix.matrix, coords=[("time", time), ("clp_label", decay_compartments)]) class OneCompartmentModelPositivSkew(OneCompartmentModelNegativeSkew): - spectral_parameters = ParameterGroup.from_list([7, 20000, 800, 1]) + spectral_parameters = Parameters.from_list([7, 20000, 800, 1]) class OneCompartmentModelZeroSkew(OneCompartmentModelNegativeSkew): - spectral_parameters = ParameterGroup.from_list([7, 20000, 800, 0]) + spectral_parameters = Parameters.from_list([7, 20000, 800, 0]) class ThreeCompartmentModel: - decay_model = DecayModel.from_dict( - { + decay_model = DecayModel( + **{ "initial_concentration": { "j1": {"compartments": ["s1", "s2", "s3"], "parameters": ["4", "4", "4"]}, }, "megacomplex": { - "mc1": {"k_matrix": ["k1"]}, + "mc1": {"type": "decay", "k_matrix": ["k1"]}, }, "k_matrix": { "k1": { @@ -195,19 +172,20 @@ class ThreeCompartmentModel: } ) - decay_parameters = ParameterGroup.from_list( + decay_parameters = Parameters.from_list( [101e-4, 101e-5, 101e-6, [1, {"vary": False, "non-negative": False}]] ) - spectral_model = SpectralModel.from_dict( - { + spectral_model = SpectralModel( + **{ "megacomplex": { "mc1": { + "type": "spectral", "shape": { "s1": "sh1", "s2": "sh2", "s3": "sh3", - } + }, }, }, "shape": { @@ -238,7 +216,7 @@ class ThreeCompartmentModel: } ) - spectral_parameters = ParameterGroup.from_list( + spectral_parameters = Parameters.from_list( [ 7, 450, @@ -256,7 +234,7 @@ class ThreeCompartmentModel: spectral = np.arange(400, 600, 5) axis = {"time": time, "spectral": spectral} - decay_dataset_model = decay_model.dataset["dataset1"].fill(decay_model, decay_parameters) + decay_dataset_model = fill_item(decay_model.dataset["dataset1"], decay_model, decay_parameters) matrix = MatrixProvider.calculate_dataset_matrix(decay_dataset_model, None, spectral, time) decay_compartments = matrix.clp_labels clp = xr.DataArray(matrix.matrix, coords=[("time", time), ("clp_label", decay_compartments)]) @@ -303,8 +281,8 @@ def test_spectral_model(suite): result = optimize(scheme) print(result.optimized_parameters) - for label, param in result.optimized_parameters.all(): - assert np.allclose(param.value, wanted_parameters.get(label).value, rtol=1e-1) + for param in result.optimized_parameters.all(): + assert np.allclose(param.value, wanted_parameters.get(param.label).value, rtol=1e-1) resultdata = result.data["dataset1"] assert np.array_equal(dataset["time"], resultdata["time"]) diff --git a/glotaran/builtin/megacomplexes/test/test_spectral_decay.py b/glotaran/builtin/megacomplexes/test/test_spectral_decay.py index 7183c27e8..c4e36116a 100644 --- a/glotaran/builtin/megacomplexes/test/test_spectral_decay.py +++ b/glotaran/builtin/megacomplexes/test/test_spectral_decay.py @@ -70,7 +70,7 @@ s3: sh3 irf: irf1: - type: spectral-multi-gaussian + type: multi-gaussian center: [irf.center] width: [irf.width] shape: @@ -245,8 +245,8 @@ def test_decay_model(suite, nnls): model = suite.model print(model.validate()) assert model.valid() - model.dataset_group_models["default"].link_clp = False - model.dataset_group_models["default"].method = ( + model.dataset_groups["default"].link_clp = False + model.dataset_groups["default"].method = ( "non_negative_least_squares" if nnls else "variable_projection" ) @@ -276,8 +276,8 @@ def test_decay_model(suite, nnls): result = 
optimize(scheme) print(result.optimized_parameters) - for label, param in result.optimized_parameters.all(): - assert np.allclose(param.value, wanted_parameters.get(label).value) + for param in result.optimized_parameters.all(): + assert np.allclose(param.value, wanted_parameters.get(param.label).value, rtol=1e-1) resultdata = result.data["dataset1"] diff --git a/glotaran/builtin/megacomplexes/test/test_spectral_decay_full_model.py b/glotaran/builtin/megacomplexes/test/test_spectral_decay_full_model.py index 12737b9ca..e490c6c47 100644 --- a/glotaran/builtin/megacomplexes/test/test_spectral_decay_full_model.py +++ b/glotaran/builtin/megacomplexes/test/test_spectral_decay_full_model.py @@ -25,6 +25,7 @@ irf: irf1 dataset4: megacomplex: [mc2] + initial_concentration: j1 megacomplex: mc1: type: decay @@ -37,7 +38,7 @@ s3: sh3 irf: irf1: - type: spectral-multi-gaussian + type: multi-gaussian center: [irf.center] width: [irf.width] shape: @@ -180,7 +181,7 @@ def test_kinetic_model(suite, nnls): model = suite.model print(model.validate()) assert model.valid() - model.dataset_group_models["default"].method = ( + model.dataset_groups["default"].method = ( "non_negative_least_squares" if nnls else "variable_projection" ) @@ -210,9 +211,8 @@ def test_kinetic_model(suite, nnls): result = optimize(scheme) print(result.optimized_parameters) - for label, param in result.optimized_parameters.all(): - print(label, param.value, wanted_parameters.get(label).value) - assert np.allclose(param.value, wanted_parameters.get(label).value, rtol=1e-1) + for param in result.optimized_parameters.all(): + assert np.allclose(param.value, wanted_parameters.get(param.label).value, rtol=1e-1) resultdata = result.data["dataset1"] diff --git a/glotaran/builtin/megacomplexes/test/test_spectral_penalties.py b/glotaran/builtin/megacomplexes/test/test_spectral_penalties.py index 5f988024d..884dd4159 100644 --- a/glotaran/builtin/megacomplexes/test/test_spectral_penalties.py +++ b/glotaran/builtin/megacomplexes/test/test_spectral_penalties.py @@ -9,10 +9,9 @@ from glotaran.builtin.megacomplexes.decay import DecayMegacomplex from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex from glotaran.io import prepare_time_trace_dataset -from glotaran.model import Megacomplex from glotaran.model import Model from glotaran.optimization.optimize import optimize -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.project import Scheme from glotaran.simulation import simulate @@ -26,26 +25,7 @@ OptimizationSpec = namedtuple("OptimizationSpec", "nnls max_nfev") -class SpectralDecayModel(Model): - @classmethod - def from_dict( - cls, - model_dict, - *, - megacomplex_types: dict[str, type[Megacomplex]] | None = None, - default_megacomplex_type: str | None = None, - ): - defaults: dict[str, type[Megacomplex]] = { - "decay": DecayMegacomplex, - "spectral": SpectralMegacomplex, - } - if megacomplex_types is not None: - defaults.update(megacomplex_types) - return super().from_dict( - model_dict, - megacomplex_types=defaults, - default_megacomplex_type=default_megacomplex_type, - ) +SpectralDecayModel = Model.create_class_from_megacomplexes([DecayMegacomplex, SpectralMegacomplex]) def plot_overview(res, title=None): @@ -160,8 +140,9 @@ def test_equal_area_penalties(debug=False): } equ_area = { - "clp_area_penalties": [ + "clp_penalties": [ { + "type": "equal_area", "source": "s1", "target": "s2", "parameter": "rela.1", @@ -201,15 +182,15 @@ def 
test_equal_area_penalties(debug=False): mspec_fit_wp = dict(deepcopy(mspec.base), **mspec.equ_area) mspec_fit_np = dict(deepcopy(mspec.base)) - model_sim = SpectralDecayModel.from_dict(mspec_sim) - model_wp = SpectralDecayModel.from_dict(mspec_fit_wp) - model_np = SpectralDecayModel.from_dict(mspec_fit_np) + model_sim = SpectralDecayModel(**mspec_sim) + model_wp = SpectralDecayModel(**mspec_fit_wp) + model_np = SpectralDecayModel(**mspec_fit_np) print(model_np) # %% Parameter specification (pspec) pspec_sim = dict(deepcopy(pspec.base), **pspec.shapes) - param_sim = ParameterGroup.from_dict(pspec_sim) + param_sim = Parameters.from_dict(pspec_sim) # For the wp model we create two version of the parameter specification # One has all inputs fixed, the other has all but the first free @@ -219,8 +200,8 @@ def test_equal_area_penalties(debug=False): pspec_wp["i"] = [[1, {"vary": False}], 1] pspec_np = dict(deepcopy(pspec.base)) - param_wp = ParameterGroup.from_dict(pspec_wp) - param_np = ParameterGroup.from_dict(pspec_np) + param_wp = Parameters.from_dict(pspec_wp) + param_np = Parameters.from_dict(pspec_np) # %% Print models with parameters print(model_sim.markdown(param_sim)) @@ -244,7 +225,7 @@ def test_equal_area_penalties(debug=False): # %% Optimizing model without penalty (np) - model_np.dataset_group_models["default"].method = ( + model_np.dataset_groups["default"].method = ( "non_negative_least_squares" if optim_spec.nnls else "variable_projection" ) @@ -258,7 +239,7 @@ def test_equal_area_penalties(debug=False): result_np = optimize(scheme_np, raise_exception=True) print(result_np) - model_wp.dataset_group_models["default"].method = ( + model_wp.dataset_groups["default"].method = ( "non_negative_least_squares" if optim_spec.nnls else "variable_projection" ) diff --git a/glotaran/deprecation/modules/builtin_io_yml.py b/glotaran/deprecation/modules/builtin_io_yml.py index 7477af42d..5ac6b9776 100644 --- a/glotaran/deprecation/modules/builtin_io_yml.py +++ b/glotaran/deprecation/modules/builtin_io_yml.py @@ -113,6 +113,15 @@ def model_spec_deprecations(spec: MutableMapping[Any, Any]) -> None: stacklevel=load_model_stack_level, ) + deprecate_dict_entry( + dict_to_check=spec, + deprecated_usage="clp_area_penalties", + new_usage="clp_penalties", + to_be_removed_in_version="0.8.0", + swap_keys=("clp_area_penalties", "clp_penalties"), + stacklevel=load_model_stack_level, + ) + if "irf" in spec: for _, irf in spec["irf"].items(): deprecate_dict_entry( diff --git a/glotaran/deprecation/modules/test/test_builtin_io_yml.py b/glotaran/deprecation/modules/test/test_builtin_io_yml.py index 7182eafdd..e42c9b632 100644 --- a/glotaran/deprecation/modules/test/test_builtin_io_yml.py +++ b/glotaran/deprecation/modules/test/test_builtin_io_yml.py @@ -78,8 +78,19 @@ - type: equal_area """ ), + 2, + "clp_penalties", + [{"type": "equal_area"}], + ), + ( + dedent( + """ + clp_area_penalties: + - type: equal_area + """ + ), 1, - "clp_area_penalties", + "clp_penalties", [{"type": "equal_area"}], ), ( @@ -119,6 +130,7 @@ "spectral_constraints", "constraints", "equal_area_penalties", + "clp_area_penalties", "center_dispersion", "width_dispersion", ), diff --git a/glotaran/deprecation/modules/test/test_changed_imports.py b/glotaran/deprecation/modules/test/test_changed_imports.py index e0ae4c9a7..1a48ed836 100644 --- a/glotaran/deprecation/modules/test/test_changed_imports.py +++ b/glotaran/deprecation/modules/test/test_changed_imports.py @@ -61,7 +61,7 @@ def test_changed_import_test_warn_attribute_no_warn( recwarn: 
WarningsRecorder, ): """Module attribute import not warning""" - changed_import_test_warn(recwarn, "glotaran.parameter", attribute_name="ParameterGroup") + changed_import_test_warn(recwarn, "glotaran.parameter", attribute_name="Parameters") @pytest.mark.xfail(strict=True, reason="Fail if no warning") diff --git a/glotaran/deprecation/modules/test/test_model_model.py b/glotaran/deprecation/modules/test/test_model_model.py deleted file mode 100644 index 70c7ffdc6..000000000 --- a/glotaran/deprecation/modules/test/test_model_model.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Tests for deprecated methods in ``glotaran.model.model``.""" -from __future__ import annotations - -import pytest - -from glotaran.deprecation.deprecation_utils import GlotaranDeprectedApiError -from glotaran.testing.simulated_data.parallel_spectral_decay import MODEL as dummy_model - - -def test_model_model_dimension(): - """Raise ``GlotaranApiDeprecationWarning``.""" - expected = ( - "Usage of 'Model.model_dimension' was deprecated, " - "use \"Scheme.model_dimensions['']\" instead.\n" - "It wasn't possible to restore the original behavior of this usage " - "(mostlikely due to an object hierarchy change)." - "This usage change message won't be show as of version: '0.7.0'." - ) - - with pytest.raises(GlotaranDeprectedApiError) as excinfo: - dummy_model.model_dimension - - assert str(excinfo.value) == expected - - -def test_model_global_dimension(): - """Raise ``GlotaranApiDeprecationWarning``.""" - expected = ( - "Usage of 'Model.global_dimension' was deprecated, " - "use \"Scheme.global_dimensions['']\" instead.\n" - "It wasn't possible to restore the original behavior of this usage " - "(mostlikely due to an object hierarchy change)." - "This usage change message won't be show as of version: '0.7.0'." 
- ) - - with pytest.raises(GlotaranDeprectedApiError) as excinfo: - dummy_model.global_dimension - - assert str(excinfo.value) == expected diff --git a/glotaran/deprecation/modules/test/test_parameter_parameter_group.py b/glotaran/deprecation/modules/test/test_parameter_parameter_group.py deleted file mode 100644 index 3794c9be5..000000000 --- a/glotaran/deprecation/modules/test/test_parameter_parameter_group.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Tests for deprecated methods in ``glotaran..parameter.ParameterGroup``.""" -from pathlib import Path -from textwrap import dedent - -from glotaran.deprecation.modules.test import deprecation_warning_on_call_test_helper -from glotaran.testing.simulated_data.sequential_spectral_decay import PARAMETERS - - -def test_parameter_group_to_csv_no_stderr(tmp_path: Path): - """``ParameterGroup.to_csv`` raises deprecation warning and saves file.""" - parameter_path = tmp_path / "test_parameter.csv" - deprecation_warning_on_call_test_helper( - PARAMETERS.to_csv, args=[parameter_path.as_posix()], raise_exception=True - ) - expected = dedent( - """\ - label,value,expression,minimum,maximum,non-negative,vary,standard-error - rates.species_1,0.5,None,-inf,inf,False,True,None - rates.species_2,0.3,None,-inf,inf,False,True,None - rates.species_3,0.1,None,-inf,inf,False,True,None - irf.center,0.3,None,-inf,inf,False,True,None - irf.width,0.1,None,-inf,inf,False,True,None - """ - ) - - assert parameter_path.is_file() - assert parameter_path.read_text() == expected diff --git a/glotaran/deprecation/modules/test/test_project_scheme.py b/glotaran/deprecation/modules/test/test_project_scheme.py deleted file mode 100644 index 9a3ee528a..000000000 --- a/glotaran/deprecation/modules/test/test_project_scheme.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Test deprecated functionality in 'glotaran.project.schmeme'.""" -from __future__ import annotations - -import pytest - -from glotaran.deprecation.modules.test import deprecation_warning_on_call_test_helper -from glotaran.project.scheme import Scheme -from glotaran.testing.simulated_data.parallel_spectral_decay import DATASET -from glotaran.testing.simulated_data.parallel_spectral_decay import MODEL -from glotaran.testing.simulated_data.parallel_spectral_decay import PARAMETERS - - -def test_scheme_group_tolerance(): - """Argument ``group_tolerance`` raises deprecation and maps to ``clp_link_tolerance``.""" - model, parameters, dataset = MODEL, PARAMETERS, DATASET - - warnings, result = deprecation_warning_on_call_test_helper( - Scheme, - args=(model, parameters, {"dataset": dataset}), - kwargs={"group_tolerance": 1}, - raise_exception=True, - ) - - assert isinstance(result, Scheme) - assert result.clp_link_tolerance == 1 - assert "glotaran.project.Scheme(..., clp_link_tolerance=...)" in warnings[0].message.args[0] - - -@pytest.mark.parametrize( - "group", - (True, False), -) -def test_scheme_group(group: bool): - """Argument ``group`` raises deprecation and maps to ``dataset_groups.default.link_clp``.""" - model, parameters, dataset = MODEL, PARAMETERS, DATASET - - warnings, result = deprecation_warning_on_call_test_helper( - Scheme, - args=(model, parameters, {"dataset": dataset}), - kwargs={"group": group}, - raise_exception=True, - ) - - assert isinstance(result, Scheme) - assert result.model.dataset_group_models["default"].link_clp == group - assert "dataset_groups.default.link_clp" in warnings[0].message.args[0] - - -@pytest.mark.parametrize( - "non_negative_least_squares, expected", - ((True, "non_negative_least_squares"), 
(False, "variable_projection")), -) -def test_scheme_non_negative_least_squares(non_negative_least_squares: bool, expected: str): - """Argument ``non_negative_least_squares`` raises deprecation and maps to - ``dataset_groups.default.residual_function``. - """ - model, parameters, dataset = MODEL, PARAMETERS, DATASET - - warnings, result = deprecation_warning_on_call_test_helper( - Scheme, - args=(model, parameters, {"dataset": dataset}), - kwargs={"non_negative_least_squares": non_negative_least_squares}, - raise_exception=True, - ) - - assert isinstance(result, Scheme) - assert result.model.dataset_group_models["default"].residual_function == expected - assert "dataset_groups.default.residual_function" in warnings[0].message.args[0] diff --git a/glotaran/io/interface.py b/glotaran/io/interface.py index fff021402..25aaf06dd 100644 --- a/glotaran/io/interface.py +++ b/glotaran/io/interface.py @@ -23,7 +23,7 @@ import xarray as xr from glotaran.model import Model - from glotaran.parameter import ParameterGroup + from glotaran.parameter import Parameters from glotaran.project import Result from glotaran.project import Scheme @@ -154,8 +154,8 @@ def save_model(self, model: Model, file_name: str): """ raise NotImplementedError(f"Cannot save models with format {self.format!r}") - def load_parameters(self, file_name: str) -> ParameterGroup: - """Create a ParameterGroup instance from the specs defined in a file. + def load_parameters(self, file_name: str) -> Parameters: + """Create a Parameters instance from the specs defined in a file. **NOT IMPLEMENTED** @@ -166,8 +166,8 @@ def load_parameters(self, file_name: str) -> ParameterGroup: Returns ------- - ParameterGroup - ParameterGroup instance created from the file. + Parameters + Parameters instance created from the file. .. # noqa: DAR202 @@ -175,15 +175,15 @@ def load_parameters(self, file_name: str) -> ParameterGroup: """ raise NotImplementedError(f"Cannot read parameters with format {self.format!r}") - def save_parameters(self, parameters: ParameterGroup, file_name: str): - """Save a ParameterGroup instance to a spec file. + def save_parameters(self, parameters: Parameters, file_name: str): + """Save a Parameters instance to a spec file. **NOT IMPLEMENTED** Parameters ---------- - parameters : ParameterGroup - ParameterGroup instance to save to specs file. + parameters : Parameters + Parameters instance to save to specs file. file_name : str File to write the parameter specs to. diff --git a/glotaran/model/__init__.py b/glotaran/model/__init__.py index 6c2a6d493..95db58bb1 100644 --- a/glotaran/model/__init__.py +++ b/glotaran/model/__init__.py @@ -1,26 +1,22 @@ -"""Glotaran Model Package - -This package contains the Glotaran's base model object, the model decorators and -common model items. 
-"""
-
+"""The glotaran model package."""
+from glotaran.model.clp_constraint import OnlyConstraint
+from glotaran.model.clp_constraint import ZeroConstraint
 from glotaran.model.clp_penalties import EqualAreaPenalty
-from glotaran.model.constraint import Constraint
-from glotaran.model.constraint import OnlyConstraint
-from glotaran.model.constraint import ZeroConstraint
+from glotaran.model.clp_relation import ClpRelation
 from glotaran.model.dataset_group import DatasetGroup
-from glotaran.model.dataset_group import DatasetGroupModel
 from glotaran.model.dataset_model import DatasetModel
-from glotaran.model.item import model_item
-from glotaran.model.item import model_item_typed
+from glotaran.model.dataset_model import get_dataset_model_model_dimension
+from glotaran.model.dataset_model import is_dataset_model_index_dependent
+from glotaran.model.item import ItemIssue
+from glotaran.model.item import ModelItem
+from glotaran.model.item import ModelItemType
+from glotaran.model.item import ModelItemTyped
+from glotaran.model.item import ParameterType
+from glotaran.model.item import attribute
+from glotaran.model.item import fill_item
+from glotaran.model.item import item
 from glotaran.model.megacomplex import Megacomplex
 from glotaran.model.megacomplex import megacomplex
 from glotaran.model.model import Model
-from glotaran.model.relation import Relation
-from glotaran.model.util import ModelError
+from glotaran.model.model import ModelError
 from glotaran.model.weight import Weight
-from glotaran.plugin_system.megacomplex_registration import get_megacomplex
-from glotaran.plugin_system.megacomplex_registration import is_known_megacomplex
-from glotaran.plugin_system.megacomplex_registration import known_megacomplex_names
-from glotaran.plugin_system.megacomplex_registration import megacomplex_plugin_table
-from glotaran.plugin_system.megacomplex_registration import set_megacomplex_plugin
diff --git a/glotaran/model/clp_constraint.py b/glotaran/model/clp_constraint.py
new file mode 100644
index 000000000..ae66a17b0
--- /dev/null
+++ b/glotaran/model/clp_constraint.py
@@ -0,0 +1,44 @@
+"""This module contains clp constraint items."""
+from __future__ import annotations
+
+from glotaran.model.interval_item import IntervalItem
+from glotaran.model.item import TypedItem
+from glotaran.model.item import item
+
+
+@item
+class ClpConstraint(TypedItem, IntervalItem):
+    """Baseclass for clp constraints.
+
+    There are two types: zero and only. See the documentation of
+    the respective classes for details.
+    """
+
+
+@item
+class ZeroConstraint(ClpConstraint):
+    """Constrains the target to 0 in the given interval."""
+
+    type: str = "zero"
+    target: str
+
+
+@item
+class OnlyConstraint(ZeroConstraint):
+    """Constrains the target to 0 outside the given interval."""
+
+    type: str = "only"
+
+    def applies(self, index: float | None) -> bool:
+        """Check if the constraint applies at this index.
+
+        Parameters
+        ----------
+        index : float
+            The index.
+
+        Returns
+        -------
+        bool
+        """
+        return not super().applies(index)
diff --git a/glotaran/model/clp_penalties.py b/glotaran/model/clp_penalties.py
index 23c7f52c5..50e028f4d 100644
--- a/glotaran/model/clp_penalties.py
+++ b/glotaran/model/clp_penalties.py
@@ -1,160 +1,31 @@
-"""This package contains compartment constraint items."""
+"""This module contains clp penalty items."""
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-from typing import List
-from typing import Tuple
+from glotaran.model.item import ParameterType
+from glotaran.model.item import TypedItem
+from glotaran.model.item import item
 
-import numpy as np
-import xarray as xr
 
-from glotaran.model.item import model_item
-from glotaran.parameter import Parameter
+@item
+class ClpPenalty(TypedItem):
+    """Baseclass for clp penalties."""
 
-if TYPE_CHECKING:
-    from collections.abc import Sequence
-    from typing import Any
 
-    from glotaran.model.model import Model
-    from glotaran.parameter import ParameterGroup
+@item
+class EqualAreaPenalty(ClpPenalty):
+    """Forces the area of 2 clp to be the same.
 
-
-@model_item(
-    properties={
-        "source": str,
-        "source_intervals": List[Tuple[float, float]],
-        "target": str,
-        "target_intervals": List[Tuple[float, float]],
-        "parameter": Parameter,
-        "weight": str,
-    },
-    has_label=False,
-)
-class EqualAreaPenalty:
-    """An equal area constraint adds a the difference of the sum of a
+    An equal area penalty adds the difference of the sum of
     compartments in the e matrix in one or more intervals to the scaled
     sum of the e matrix of one or more target compartments to residual. The additional
-    residual is scaled with the weight."""
-
-    def applies(self, index: Any) -> bool:
-        """
-        Returns true if the index is in one of the intervals.
- - Parameters - ---------- - index : - - Returns - ------- - applies : bool - - """ - - def applies(interval): - return interval[0] <= index <= interval[1] - - if isinstance(self.interval, tuple): - return applies(self.interval) - return any(applies(i) for i in self.interval) - - -def has_spectral_penalties(model: Model) -> bool: - return len(model.clp_area_penalties) != 0 - - -def apply_spectral_penalties( - model: Model, - parameters: ParameterGroup, - clp_labels: dict[str, list[str] | list[list[str]]], - clps: dict[str, list[np.ndarray]], - matrices: dict[str, np.ndarray | list[np.ndarray]], - data: dict[str, xr.Dataset], - group_tolerance: float, -) -> np.ndarray: - - # TODO: seems to duplicate calculate_clp_penalties - penalties = [] - for penalty in model.clp_area_penalties: - - penalty = penalty.fill(model, parameters) - source_area = _get_area( - model.index_dependent(), - model.global_dimension, - clp_labels, - clps, - data, - group_tolerance, - penalty.source_intervals, - penalty.source, - ) - - target_area = _get_area( - model.index_dependent(), - model.global_dimension, - clp_labels, - clps, - data, - group_tolerance, - penalty.target_intervals, - penalty.target, - ) - - area_penalty = np.abs(np.sum(source_area) - penalty.parameter * np.sum(target_area)) - penalties.append(area_penalty * penalty.weight) - return np.asarray(penalties) - - -def _get_area( - index_dependent: bool, - global_dimension: str, - clp_labels: dict[str, list[list[str]]], - clps: dict[str, list[np.ndarray]], - data: dict[str, xr.Dataset], - group_tolerance: float, - intervals: list[tuple[float, float]], - compartment: str, -) -> np.ndarray: - area = [] - area_indices = [] - - for label, dataset in data.items(): - global_axis = dataset.coords[global_dimension] - for interval in intervals: - if interval[0] > global_axis[-1]: - # interval not in this dataset - continue - - start_idx, end_idx = _get_idx_from_interval(interval, global_axis) - for i in range(start_idx, end_idx + 1): - index_clp_labels = clp_labels[label][i] if index_dependent else clp_labels[label] - if compartment in index_clp_labels: - area.append(clps[label][i][index_clp_labels.index(compartment)]) - area_indices.append(global_axis[i]) - - return np.asarray(area) # TODO: normalize for distance on global axis - - -def _get_idx_from_interval( - interval: tuple[float, float], axis: Sequence[float] | np.ndarray -) -> tuple[int, int]: - """Retrieves start and end index of an interval on some axis - - Parameters - ---------- - interval : A tuple of floats with begin and end of the interval - axis : Array like object which can be cast to np.array - - Returns - ------- - start, end : tuple of int - + residual is scaled with the weight. 
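+
+    For illustration, following the removed reference implementation the extra
+    residual is ``weight * |sum(source_area) - parameter * sum(target_area)|``,
+    with the areas collected over the source and target intervals.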
""" - axis_array = np.array(axis) - start = 0 if np.isinf(interval[0]) else np.abs(axis_array - interval[0]).argmin() - - end = ( - axis_array.size - 1 if np.isinf(interval[1]) else np.abs(axis_array - interval[1]).argmin() - ) - return start, end + type: str = "equal_area" + source: str + source_intervals: list[tuple[float, float]] + target: str + target_intervals: list[tuple[float, float]] + parameter: ParameterType + weight: float diff --git a/glotaran/model/clp_penalties.pyi b/glotaran/model/clp_penalties.pyi deleted file mode 100644 index 54fca6b62..000000000 --- a/glotaran/model/clp_penalties.pyi +++ /dev/null @@ -1,44 +0,0 @@ -from collections.abc import Sequence -from typing import Any - -import numpy as np -import xarray as xr - -from glotaran.model.item import model_item -from glotaran.model.model import Model -from glotaran.parameter import Parameter -from glotaran.parameter import ParameterGroup - -class EqualAreaPenalty: - source: str - source_intervals: list[tuple[float, float]] - target: str - target_intervals: list[tuple[float, float]] - parameter: Parameter - weight: str - def applies(self, index: Any) -> bool: ... - def fill(self, model: Model, parameters: ParameterGroup | None) -> EqualAreaPenalty: ... - -def has_spectral_penalties(model: Model) -> bool: ... -def apply_spectral_penalties( - model: Model, - parameters: ParameterGroup, - clp_labels: dict[str, list[str] | list[list[str]]], - clps: dict[str, list[np.ndarray]], - matrices: dict[str, np.ndarray | list[np.ndarray]], - data: dict[str, xr.Dataset], - group_tolerance: float, -) -> np.ndarray: ... -def _get_area( - index_dependent: bool, - global_dimension: str, - clp_labels: dict[str, list[list[str]]], - clps: dict[str, list[np.ndarray]], - data: dict[str, xr.Dataset], - group_tolerance: float, - intervals: list[tuple[float, float]], - compartment: str, -) -> np.ndarray: ... -def _get_idx_from_interval( - interval: tuple[float, float], axis: Sequence[float] | np.ndarray -) -> tuple[int, int]: ... diff --git a/glotaran/model/clp_relation.py b/glotaran/model/clp_relation.py new file mode 100644 index 000000000..d97e1aae0 --- /dev/null +++ b/glotaran/model/clp_relation.py @@ -0,0 +1,18 @@ +"""This module contains clp relation items.""" +from __future__ import annotations + +from glotaran.model.interval_item import IntervalItem +from glotaran.model.item import ParameterType +from glotaran.model.item import item + + +@item +class ClpRelation(IntervalItem): + """Applies a relation between two clps. + + The relation is applied as :math:`target = parameter * source`. + """ + + source: str + target: str + parameter: ParameterType diff --git a/glotaran/model/constraint.py b/glotaran/model/constraint.py deleted file mode 100644 index 0c3fc80d2..000000000 --- a/glotaran/model/constraint.py +++ /dev/null @@ -1,61 +0,0 @@ -"""This package contains compartment constraint items.""" -from __future__ import annotations - -from glotaran.model.interval_property import IntervalProperty -from glotaran.model.item import model_item -from glotaran.model.item import model_item_typed - - -@model_item( - properties={ - "target": str, - }, - has_type=True, - has_label=False, -) -class OnlyConstraint(IntervalProperty): - """A only constraint sets the calculated matrix row of a compartment to 0 - outside the given intervals.""" - - def applies(self, value: float) -> bool: - """ - Returns true if ``value`` is in one of the intervals. 
-
-        Parameters
-        ----------
-        index : float
-
-        Returns
-        -------
-        applies : bool
-
-        """
-        return not super().applies(value)
-
-
-@model_item(
-    properties={
-        "target": str,
-    },
-    has_type=True,
-    has_label=False,
-)
-class ZeroConstraint(IntervalProperty):
-    """A zero constraint sets the calculated matrix row of a compartment to 0
-    in the given intervals."""
-
-
-@model_item_typed(
-    types={
-        "only": OnlyConstraint,
-        "zero": ZeroConstraint,
-    },
-    has_label=False,
-)
-class Constraint:
-    """A constraint is applied on one clp on one or many
-    intervals on the estimated axis type.
-
-    There are two types: zero and equal. See the documentation of
-    the respective classes for details.
-    """
diff --git a/glotaran/model/dataset_group.py b/glotaran/model/dataset_group.py
index 4b599e431..ee6c0c197 100644
--- a/glotaran/model/dataset_group.py
+++ b/glotaran/model/dataset_group.py
@@ -1,21 +1,27 @@
+"""This module contains the dataset group."""
 from __future__ import annotations
 
-from dataclasses import dataclass
-from dataclasses import field
 from typing import TYPE_CHECKING
 from typing import Literal
 
 import xarray as xr
+from attrs import define
+from attrs import field
 
 from glotaran.model.dataset_model import DatasetModel
+from glotaran.model.dataset_model import get_dataset_model_model_dimension
+from glotaran.model.dataset_model import has_dataset_model_global_model
+from glotaran.model.item import ModelItem
+from glotaran.model.item import fill_item
+from glotaran.model.item import item
 
 if TYPE_CHECKING:
     from glotaran.model.model import Model
-    from glotaran.parameter import ParameterGroup
+    from glotaran.parameter import Parameters
 
 
-@dataclass
-class DatasetGroupModel:
-    """A group of datasets which will evaluated independently."""
+@item
+class DatasetGroupModel(ModelItem):
+    """A group of datasets which will be evaluated independently."""
 
     residual_function: Literal[
@@ -27,7 +33,7 @@ class DatasetGroupModel:
     """Whether to link the clp parameter."""
 
 
-@dataclass
+@define
 class DatasetGroup:
     """A dataset group for optimization."""
 
@@ -38,22 +44,45 @@ class DatasetGroup:
     """Whether to link the clp parameter."""
 
     model: Model
-    parameters: ParameterGroup | None = None
+    parameters: Parameters | None = None
 
-    dataset_models: dict[str, DatasetModel] = field(default_factory=dict)
+    dataset_models: dict[str, DatasetModel] = field(factory=dict)
 
-    def set_parameters(self, parameters: ParameterGroup):
+    def set_parameters(self, parameters: Parameters):
+        """Set the group parameters.
+
+        Parameters
+        ----------
+        parameters : Parameters
+            The parameters.
+        """
         self.parameters = parameters
         for label in self.dataset_models:
-            self.dataset_models[label] = self.model.dataset[label].fill(self.model, parameters)
-
-    def is_linkable(self, parameters: ParameterGroup, data: dict[str, xr.Dataset]) -> bool:
-        if any(d.has_global_model() for d in self.dataset_models.values()):
+            self.dataset_models[label] = fill_item(
+                self.model.dataset[label], self.model, parameters
+            )
+
+    def is_linkable(self, parameters: Parameters, data: dict[str, xr.Dataset]) -> bool:
+        """Check if the group is linkable.
+
+        Parameters
+        ----------
+        parameters : Parameters
+            The parameters.
+        data : dict[str, xr.Dataset]
+            The data to link.
+ + Returns + ------- + bool + """ + if any(has_dataset_model_global_model(d) for d in self.dataset_models.values()): return False dataset_models = [ - self.model.dataset[label].fill(self.model, parameters) for label in self.dataset_models + fill_item(self.model.dataset[label], self.model, parameters) + for label in self.dataset_models ] - model_dimensions = {d.get_model_dimension() for d in dataset_models} + model_dimensions = {get_dataset_model_model_dimension(d) for d in dataset_models} if len(model_dimensions) != 1: return False global_dimensions = set() diff --git a/glotaran/model/dataset_model.py b/glotaran/model/dataset_model.py index b07224cde..586d4b270 100644 --- a/glotaran/model/dataset_model.py +++ b/glotaran/model/dataset_model.py @@ -1,177 +1,340 @@ -"""The DatasetModel class.""" +"""This module contains the dataset model.""" from __future__ import annotations -import contextlib -from collections import Counter from typing import TYPE_CHECKING +from typing import Generator import xarray as xr -from glotaran.model.item import model_item -from glotaran.model.item import model_item_validator +from glotaran.model.item import ItemIssue +from glotaran.model.item import ModelItem +from glotaran.model.item import ModelItemType +from glotaran.model.item import ParameterType +from glotaran.model.item import attribute +from glotaran.model.item import item +from glotaran.model.megacomplex import Megacomplex +from glotaran.model.megacomplex import is_exclusive +from glotaran.model.megacomplex import is_unique if TYPE_CHECKING: - from typing import Any - from typing import Generator - - from glotaran.model.megacomplex import Megacomplex from glotaran.model.model import Model from glotaran.parameter import Parameter + from glotaran.parameter import Parameters -def create_dataset_model_type(properties: dict[str, Any]) -> type[DatasetModel]: - """Create dataset model type for a model.""" - - @model_item(properties=properties) - class ModelDatasetModel(DatasetModel): - pass - - return ModelDatasetModel +class ExclusiveMegacomplexIssue(ItemIssue): + """Issue for exclusive megacomplexes.""" + def __init__(self, label: str, megacomplex_type: str, is_global: bool): + """Create an ExclusiveMegacomplexIssue. -class DatasetModel: - """A `DatasetModel` describes a dataset in terms of a glotaran model. - It contains references to model items which describe the physical model for - a given dataset. - - A general dataset descriptor assigns one or more megacomplexes and a scale - parameter. - """ + Parameters + ---------- + label : str + The megacomplex label. + megacomplex_type : str + The megacomplex type. + is_global : bool + Whether the megacomplex is global. 
+ """ + self._label = label + self._type = megacomplex_type + self._is_global = is_global - def iterate_megacomplexes( - self, - ) -> Generator[tuple[Parameter | str | None, Megacomplex | str], None, None]: - """Iterates the dataset model's megacomplexes.""" - for i, megacomplex in enumerate(self.megacomplex): - scale = self.megacomplex_scale[i] if self.megacomplex_scale is not None else None - yield scale, megacomplex - - def iterate_global_megacomplexes( - self, - ) -> Generator[tuple[Parameter | str | None, Megacomplex | str], None, None]: - """Iterates the dataset model's global megacomplexes.""" - for i, megacomplex in enumerate(self.global_megacomplex): - scale = ( - self.global_megacomplex_scale[i] - if self.global_megacomplex_scale is not None - else None - ) - yield scale, megacomplex - - def get_model_dimension(self) -> str: - """Returns the dataset model's model dimension.""" - if len(self.megacomplex) == 0: - raise ValueError(f"No megacomplex set for dataset model '{self.label}'") - if isinstance(self.megacomplex[0], str): - raise ValueError(f"Dataset model '{self.label}' was not filled") - model_dimension = self.megacomplex[0].dimension - if any(model_dimension != m.dimension for m in self.megacomplex): - raise ValueError( - f"Megacomplex dimensions do not match for dataset model '{self.label}'." - ) - return model_dimension - - def finalize_data(self, dataset: xr.Dataset): - """Finalize a dataset by applying all megacomplex finalize methods.""" - is_full_model = self.has_global_model() - for megacomplex in self.megacomplex: - megacomplex.finalize_data(self, dataset, is_full_model=is_full_model) - if is_full_model: - for megacomplex in self.global_megacomplex: - megacomplex.finalize_data( - self, dataset, is_full_model=is_full_model, as_global=True - ) + def to_string(self) -> str: + """Get the issue as string. - def overwrite_model_dimension(self, model_dimension: str) -> None: - """Overwrites the dataset model's model dimension.""" - self._model_dimension = model_dimension + Returns + ------- + str + """ + return ( + f"Exclusive {'global ' if self._is_global else ''}megacomplex '{self._label}' of " + f"type '{self._type}' cannot be combined with other megacomplexes." + ) - def is_index_dependent(self) -> bool: - """Indicates if the dataset model is index dependent.""" - if self.force_index_dependent: - return True - return any(m.index_dependent(self) for m in self.megacomplex) - def has_global_model(self) -> bool: - """Indicates if the dataset model can model the global dimension.""" - return self.global_megacomplex is not None and len(self.global_megacomplex) != 0 +class UniqueMegacomplexIssue(ItemIssue): + """Issue for unique megacomplexes.""" - @model_item_validator(False) - def ensure_unique_megacomplexes(self, model: Model) -> list[str]: - """Ensure that unique megacomplexes are only used once per dataset. + def __init__(self, label: str, megacomplex_type: str, is_global: bool): + """Create a UniqueMegacomplexIssue. Parameters ---------- - model : Model - Model object using this dataset model. - - Returns - ------- - list[str] - Error messages to be shown when the model gets validated. + label : str + The megacomplex label. + megacomplex_type : str + The megacomplex type. + is_global : bool + Whether the megacomplex is global. 
""" - errors = [] - - def get_unique_errors(megacomplexes: list[str], is_global: bool) -> list[str]: - unique_types = [] - for megacomplex_name in megacomplexes: - with contextlib.suppress(KeyError): - megacomplex_instance = model.megacomplex[megacomplex_name] - if type(megacomplex_instance).glotaran_unique(): - type_name = megacomplex_instance.type or megacomplex_instance.name - unique_types.append(type_name) - this_errors = [ - f"Multiple instances of unique{' global ' if is_global else ' '}" - f"megacomplex type {type_name!r} in dataset {self.label!r}" - for type_name, count in Counter(unique_types).most_common() - if count > 1 - ] - - return this_errors - - if self.megacomplex: - errors += get_unique_errors(self.megacomplex, False) - if self.global_megacomplex: - errors += get_unique_errors(self.global_megacomplex, True) - - return errors - - @model_item_validator(False) - def ensure_exclusive_megacomplexes(self, model: Model) -> list[str]: - """Ensure that exclusive megacomplexes are the only megacomplex in the dataset model. + self._label = label + self._type = megacomplex_type + self._is_global = is_global - Parameters - ---------- - model : Model - Model object using this dataset model. + def to_string(self): + """Get the issue as string. Returns ------- - list[str] - Error messages to be shown when the model gets validated. + str """ + return ( + f"Unique {'global ' if self._is_global else ''}megacomplex '{self._label}' of " + f"type '{self._type}' can only be used once per dataset." + ) + + +def get_megacomplex_issues( + value: list[str | Megacomplex] | None, model: Model, is_global: bool +) -> list[ItemIssue]: + """Get issues for megacomplexes. + + Parameters + ---------- + value: list[str | Megacomplex] | None + A list of megacomplexes. + model: Model + The model. + is_global: bool + Whether the megacomplexes are global. + + Returns + ------- + list[ItemIssue] + """ + issues: list[ItemIssue] = [] + + if value is not None: + labels = [v if isinstance(v, str) else v.label for v in value] + megacomplexes = [model.megacomplex[label] for label in labels] + for megacomplex in megacomplexes: + megacomplex_type = megacomplex.__class__ + if is_exclusive(megacomplex_type) and len(megacomplexes) > 1: + issues.append( + ExclusiveMegacomplexIssue(megacomplex.label, megacomplex.type, is_global) + ) + if ( + is_unique(megacomplex_type) + and len([m for m in megacomplexes if m.__class__ is megacomplex_type]) > 1 + ): + issues.append( + UniqueMegacomplexIssue(megacomplex.label, megacomplex.type, is_global) + ) + return issues + + +def validate_megacomplexes( + value: list[str | Megacomplex], + dataset_model: DatasetModel, + model: Model, + parameters: Parameters | None, +) -> list[ItemIssue]: + """Get issues for dataset model megacomplexes. + + Parameters + ---------- + value: list[str | Megacomplex] + A list of megacomplexes. + dataset_model: DatasetModel + The dataset model. + model: Model + The model. + parameters: Parameters | None, + The parameters. + + Returns + ------- + list[ItemIssue] + """ + return get_megacomplex_issues(value, model, False) + + +def validate_global_megacomplexes( + value: list[str | Megacomplex] | None, + dataset_model: DatasetModel, + model: Model, + parameters: Parameters | None, +) -> list[ItemIssue]: + """Get issues for dataset model global megacomplexes. + + Parameters + ---------- + value: list[str | Megacomplex] | None + A list of megacomplexes. + dataset_model: DatasetModel + The dataset model. + model: Model + The model. 
+    parameters: Parameters | None
+        The parameters.
+
+    Returns
+    -------
+    list[ItemIssue]
+    """
+    return get_megacomplex_issues(value, model, True)
+
+
+@item
+class DatasetModel(ModelItem):
+    """A model for datasets."""
+
+    group: str = "default"
+    force_index_dependent: bool = False
+    megacomplex: list[ModelItemType[Megacomplex]] = attribute(
+        validator=validate_megacomplexes  # type:ignore[arg-type]
+    )
+    megacomplex_scale: list[ParameterType] | None = None
+    global_megacomplex: list[ModelItemType[Megacomplex]] | None = attribute(
+        alias="megacomplex",
+        default=None,
+        validator=validate_global_megacomplexes,  # type:ignore[arg-type]
+    )
+    global_megacomplex_scale: list[ParameterType] | None = None
+    scale: ParameterType | None = None
+
+
+def is_dataset_model_index_dependent(dataset_model: DatasetModel) -> bool:
+    """Check if the dataset model is index dependent.
+
+    Parameters
+    ----------
+    dataset_model: DatasetModel
+        The dataset model.
+
+    Returns
+    -------
+    bool
+    """
+    if dataset_model.force_index_dependent:
+        return True
+    return any(
+        m.index_dependent(dataset_model)  # type:ignore[union-attr]
+        for m in dataset_model.megacomplex
+    )
 
-        errors = []
 
-        def get_exclusive_errors(megacomplexes: list[str]) -> list[str]:
-            with contextlib.suppress(StopIteration):
-                exclusive_megacomplex = next(
-                    model.megacomplex[label]
-                    for label in megacomplexes
-                    if label in model.megacomplex
-                    and type(model.megacomplex[label]).glotaran_exclusive()
-                )
-                if len(self.megacomplex) != 1:
-                    return [
-                        f"Megacomplex '{type(exclusive_megacomplex)}' is exclusive and cannot be "
-                        f"combined with other megacomplex in dataset model '{self.label}'."
-                    ]
-            return []
-
-        if self.megacomplex:
-            errors += get_exclusive_errors(self.megacomplex)
-        if self.global_megacomplex:
-            errors += get_exclusive_errors(self.global_megacomplex)
-
-        return errors
+def has_dataset_model_global_model(dataset_model: DatasetModel) -> bool:
+    """Check if the dataset model can model the global dimension.
+
+    Parameters
+    ----------
+    dataset_model: DatasetModel
+        The dataset model.
+
+    Returns
+    -------
+    bool
+    """
+    return (
+        dataset_model.global_megacomplex is not None and len(dataset_model.global_megacomplex) != 0
+    )
+
+
+def get_dataset_model_model_dimension(dataset_model: DatasetModel) -> str:
+    """Get the dataset model's model dimension.
+
+    Parameters
+    ----------
+    dataset_model: DatasetModel
+        The dataset model.
+
+    Returns
+    -------
+    str
+
+    Raises
+    ------
+    ValueError
+        Raised if the dataset model does not have megacomplexes or if it is not filled.
+    """
+    if len(dataset_model.megacomplex) == 0:
+        raise ValueError(f"No megacomplex set for dataset model '{dataset_model.label}'.")
+    if any(isinstance(m, str) for m in dataset_model.megacomplex):
+        raise ValueError(f"Dataset model '{dataset_model.label}' was not filled.")
+    model_dimension: str = dataset_model.megacomplex[
+        0
+    ].dimension  # type:ignore[union-attr, assignment]
+    if any(
+        model_dimension != m.dimension  # type:ignore[union-attr]
+        for m in dataset_model.megacomplex
+    ):
+        raise ValueError(
+            f"Megacomplex dimensions do not match for dataset model '{dataset_model.label}'."
+        )
+    return model_dimension
+
+
+def iterate_dataset_model_megacomplexes(
+    dataset_model: DatasetModel,
+) -> Generator[tuple[Parameter | str | None, Megacomplex | str], None, None]:
+    """Iterate the dataset model's megacomplexes.
+
+    Parameters
+    ----------
+    dataset_model: DatasetModel
+        The dataset model.
+ + Yields + ------ + tuple[Parameter | str | None, Megacomplex | str] + A scale and megacomplex. + """ + for i, megacomplex in enumerate(dataset_model.megacomplex): + scale = ( + dataset_model.megacomplex_scale[i] + if dataset_model.megacomplex_scale is not None + else None + ) + yield scale, megacomplex + + +def iterate_dataset_model_global_megacomplexes( + dataset_model: DatasetModel, +) -> Generator[tuple[Parameter | str | None, Megacomplex | str], None, None]: + """Iterate the dataset model's global megacomplexes. + + Parameters + ---------- + dataset_model: DatasetModel + The dataset model. + + Yields + ------ + tuple[Parameter | str | None, Megacomplex | str] + A scale and megacomplex. + """ + if dataset_model.global_megacomplex is None: + return + for i, megacomplex in enumerate(dataset_model.global_megacomplex): + scale = ( + dataset_model.global_megacomplex_scale[i] + if dataset_model.global_megacomplex_scale is not None + else None + ) + yield scale, megacomplex + + +def finalize_dataset_model(dataset_model: DatasetModel, dataset: xr.Dataset): + """Finalize a dataset by applying all megacomplex finalize methods. + + Parameters + ---------- + dataset_model: DatasetModel + The dataset model. + dataset: xr.Dataset + The dataset. + """ + is_full_model = has_dataset_model_global_model(dataset_model) + for megacomplex in dataset_model.megacomplex: + megacomplex.finalize_data( # type:ignore[union-attr] + dataset_model, dataset, is_full_model=is_full_model + ) + if is_full_model and dataset_model.global_megacomplex is not None: + for megacomplex in dataset_model.global_megacomplex: + megacomplex.finalize_data( # type:ignore[union-attr] + dataset_model, dataset, is_full_model=is_full_model, as_global=True + ) diff --git a/glotaran/model/dataset_model.pyi b/glotaran/model/dataset_model.pyi deleted file mode 100644 index 989a0f050..000000000 --- a/glotaran/model/dataset_model.pyi +++ /dev/null @@ -1,46 +0,0 @@ -from collections.abc import Generator -from collections.abc import Hashable -from typing import Any - -import numpy as np -import xarray as xr - -from glotaran.model.megacomplex import Megacomplex -from glotaran.model.model import Model -from glotaran.parameter import Parameter - -def create_dataset_model_type(properties: dict[str, Any]) -> type[DatasetModel]: ... - -class DatasetModel: - - label: str - megacomplex: list[str] - megacomplex_scale: list[Parameter] | None - global_megacomplex: list[str] - global_megacomplex_scale: list[Parameter] | None - scale: Parameter | None - _coords: dict[Hashable, np.ndarray] - def iterate_megacomplexes( - self, - ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: ... - def iterate_global_megacomplexes( - self, - ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: ... - def get_model_dimension(self) -> str: ... - def finalize_data(self, dataset: xr.Dataset) -> None: ... - def overwrite_model_dimension(self, model_dimension: str) -> None: ... - def get_global_dimension(self) -> str: ... - def overwrite_global_dimension(self, global_dimension: str) -> None: ... - def swap_dimensions(self) -> None: ... - def set_data(self, dataset: xr.Dataset) -> DatasetModel: ... - def get_data(self) -> np.ndarray: ... - def get_weight(self) -> np.ndarray | None: ... - def is_index_dependent(self) -> bool: ... - def overwrite_index_dependent(self, index_dependent: bool): ... - def has_global_model(self) -> bool: ... - def set_coordinates(self, coords: dict[str, np.ndarray]): ... 
- def get_coordinates(self) -> dict[Hashable, np.ndarray]: ... - def get_model_axis(self) -> np.ndarray: ... - def get_global_axis(self) -> np.ndarray: ... - def ensure_unique_megacomplexes(self, model: Model) -> list[str]: ... - def ensure_exclusive_megacomplexes(self, model: Model) -> list[str]: ... diff --git a/glotaran/model/interval_item.py b/glotaran/model/interval_item.py new file mode 100644 index 000000000..a9d508d9f --- /dev/null +++ b/glotaran/model/interval_item.py @@ -0,0 +1,47 @@ +"""This module contains the interval item.""" +from __future__ import annotations + +from glotaran.model.item import Item +from glotaran.model.item import item + + +@item +class IntervalItem(Item): + """An item with an interval.""" + + interval: tuple[float, float] | list[tuple[float, float]] | None = None + + def has_interval(self) -> bool: + """Check if intervals are defined. + + Returns + ------- + bool + """ + return self.interval is not None + + def applies(self, index: float | None) -> bool: + """Check if the index is in the intervals. + + Parameters + ---------- + index : float + The index. + + Returns + ------- + bool + + """ + if self.interval is None or index is None: + return True + + def applies(interval: tuple[float, float]): + lower, upper = interval[0], interval[1] + if lower > upper: + lower, upper = upper, lower + return lower <= index <= upper # type:ignore[operator] + + if isinstance(self.interval, tuple): + return applies(self.interval) + return any(applies(i) for i in self.interval) diff --git a/glotaran/model/interval_property.py b/glotaran/model/interval_property.py deleted file mode 100644 index 62e30bb87..000000000 --- a/glotaran/model/interval_property.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Helper functions.""" -from __future__ import annotations - -from typing import List -from typing import Tuple - -from glotaran.model.item import model_item - - -@model_item( - properties={ - "interval": {"type": List[Tuple[float, float]], "default": None, "allow_none": True}, - }, - has_label=False, -) -class IntervalProperty: - """Applies a relation between clps as - - :math:`source = parameter * target`. - """ - - def has_interval(self) -> bool: - return self.interval is not None - - def applies(self, value: float | None) -> bool: - """ - Returns true if ``value`` is in one of the intervals. 
-
-        Parameters
-        ----------
-        value : float
-
-        Returns
-        -------
-        applies : bool
-
-        """
-        if self.interval is None:
-            return True
-
-        def applies(interval):
-            return interval[0] <= value <= interval[1]
-
-        if isinstance(self.interval, tuple):
-            return applies(self.interval)
-        return any(applies(i) for i in self.interval)
diff --git a/glotaran/model/item.py b/glotaran/model/item.py
index a4732aa5e..6e8e997d4 100644
--- a/glotaran/model/item.py
+++ b/glotaran/model/item.py
@@ -1,393 +1,686 @@
-"""The model item decorator."""
+"""This module contains the items."""
 from __future__ import annotations
 
-import copy
+import contextlib
+from inspect import getmro
+from inspect import isclass
 from textwrap import indent
+from types import NoneType
+from types import UnionType
 from typing import TYPE_CHECKING
+from typing import Any
 from typing import Callable
-from typing import List
+from typing import ClassVar
+from typing import Generator
+from typing import Iterator
 from typing import Type
-
-from glotaran.model.property import ModelProperty
-from glotaran.model.util import wrap_func_as_method
+from typing import TypeAlias
+from typing import TypeVar
+from typing import Union
+from typing import get_args
+from typing import get_origin
+
+from attrs import NOTHING
+from attrs import Attribute
+from attrs import define
+from attrs import evolve
+from attrs import field
+from attrs import fields
+from attrs import resolve_types
+
+from glotaran.parameter import Parameter
+from glotaran.parameter import Parameters
 from glotaran.utils.ipython import MarkdownStr
 
 if TYPE_CHECKING:
-    from typing import Any
-
     from glotaran.model.model import Model
-    from glotaran.parameter import ParameterGroup
 
-    Validator = Callable[
-        [Type[object], Type[Model]],
-        List[str],
-    ]
-    ValidatorParameter = Callable[
-        [Type[object], Type[Model], Type[ParameterGroup]],
-        List[str],
-    ]
+META_ALIAS = "__glotaran_alias__"
+META_VALIDATOR = "__glotaran_validator__"
+
+
+class ItemIssue:
+    """Baseclass for item issues."""
+
+    def to_string(self) -> str:
+        """Get the issue as string.
+
+        Returns
+        -------
+        str
+
+        .. # noqa: DAR202
+        .. # noqa: DAR401
+        """
+        raise NotImplementedError
+
+    def __repr__(self) -> str:
+        """Get the representation."""
+        return self.to_string()
+
+
+class ModelItemIssue(ItemIssue):
+    """Issue for missing model items."""
+
+    def __init__(self, item_name: str, label: str):
+        """Create a model item issue.
+
+        Parameters
+        ----------
+        item_name : str
+            The name of the item.
+        label : str
+            The item label.
+        """
+        self._item_name = item_name
+        self._label = label
+
+    def to_string(self) -> str:
+        """Get the issue as string.
+
+        Returns
+        -------
+        str
+        """
+        return f"Missing model item '{self._item_name}' with label '{self._label}'."
+
+
+class ParameterIssue(ItemIssue):
+    """Issue for missing parameters."""
+
+    def __init__(self, label: str):
+        """Create a parameter issue.
+
+        Parameters
+        ----------
+        label : str
+            The parameter label.
+        """
+        self._label = label
+
+    def to_string(self) -> str:
+        """Get the issue as string.
+
+        Returns
+        -------
+        str
+        """
+        return f"Missing parameter with label '{self._label}'."
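+
+# Usage sketch (illustrative, not part of the module's API): issues render
+# themselves via ``to_string``; the parameter label below is hypothetical.
+#
+#     issue = ParameterIssue("rates.k1")
+#     issue.to_string()  # "Missing parameter with label 'rates.k1'."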
+ + +class Item: + """A baseclass for items.""" + + +@define(kw_only=True, slots=False) +class ModelItem(Item): + """An item with a label.""" + + label: str + +@define(kw_only=True, slots=False) +class TypedItem(Item): + """An item with a type.""" -def model_item( - properties: None | dict[str, dict[str, Any]] = None, - has_type: bool = False, - has_label: bool = True, -) -> Callable: - """The `@model_item` decorator adds the given properties to the class. Further it adds - classmethods for deserialization, validation and printing. + type: str + __item_types__: ClassVar[dict[str, Type]] - By default, a `label` property is added. + @classmethod + def _register_item_class(cls): + """Register a class as type.""" + item_type = cls.get_item_type() + if item_type is not NOTHING: + cls.__item_types__[item_type] = cls + + @classmethod + def get_item_type(cls) -> str: + """Get the type string. + + Returns + ------- + str + """ + return fields(cls).type.default + + @classmethod + def get_item_types(cls) -> list[str]: + """Get all type strings. + + Returns + ------- + list[str] + """ + return list(cls.__item_types__.keys()) + + @classmethod + def get_item_type_class(cls, item_type: str) -> Type: + """Get the type for a type string. + + Parameters + ---------- + item_type: str + The type string. + Returns + ------- + Type + """ + return cls.__item_types__[item_type] + + +@define(kw_only=True, slots=False) +class ModelItemTyped(TypedItem, ModelItem): + """A model item with a type.""" + + +ItemT = TypeVar("ItemT", bound="Item") +ModelItemT = TypeVar("ModelItemT", bound="ModelItem") - The `properties` dictionary contains the name of the properties as keys. The values must be - either a `type` or dictionary with the following values: +ParameterType: TypeAlias = Parameter | str +ModelItemType: TypeAlias = ModelItemT | str # type:ignore[operator] - * type: a `type` (required) - * doc: a string for documentation (optional) - * default: a default value (optional) - * allow_none: if `True`, the property can be set to None (optional) - Classes with the `model_item` decorator intended to be used in glotaran models. +def item_to_markdown( + item: Item, parameters: Parameters | None = None, initial_parameters: Parameters | None = None +) -> MarkdownStr: + """Get the item as markdown string. Parameters ---------- - properties : - A dictionary of property names and options. - has_type : - If true, a type property will added. Used for model attributes, which - can have more then one type. - has_label : - If false no label property will be added. + item: Item + The item. + parameters: Parameters | None + The parameters. + initial_parameters: Parameters | None + The initial parameters. 
+ + Returns + ------- + MarkdownStr """ + md = "\n" + for attr in fields(item.__class__): + name = attr.name + value = getattr(item, name) + if value is None: + continue + + structure, item_type = strip_type_and_structure_from_attribute(attr) + if item_type is Parameter and parameters is not None: + if structure is dict: + value = { + k: parameters.get(v.label if isinstance(v, Parameter) else v).markdown( + parameters, initial_parameters + ) + for k, v in value.items() + } + elif structure is list: + value = [ + parameters.get(v.label if isinstance(v, Parameter) else v).markdown( + parameters, initial_parameters + ) + for v in value + ] + else: + value = parameters.get( + value.label if isinstance(value, Parameter) else value + ).markdown(parameters, initial_parameters) - if properties is None: - properties = {} + property_md = indent(f"* *{name.replace('_', ' ').title()}*: {value}\n", " ") - def decorator(cls): + md += property_md - setattr(cls, "_glotaran_has_label", has_label) - setattr(cls, "_glotaran_model_item", True) + return MarkdownStr(md) - # store for later sanity checking - if not hasattr(cls, "_glotaran_properties"): - setattr(cls, "_glotaran_properties", []) - if has_label: - doc = f"The label of {cls.__name__} item." - prop = ModelProperty(cls, "label", str, doc, None, False) - setattr(cls, "label", prop) - getattr(cls, "_glotaran_properties").append("label") - if has_type: - doc = f"The type string of {cls.__name__}." - prop = ModelProperty(cls, "type", str, doc, None, True) - setattr(cls, "type", prop) - getattr(cls, "_glotaran_properties").append("type") - else: - setattr( - cls, - "_glotaran_properties", - list(getattr(cls, "_glotaran_properties")), - ) +def iterate_attributes_of_type( + item: type[Item], attr_type: type +) -> Generator[Attribute, None, None]: + """Get attributes of type from an item type. - for name, options in properties.items(): - if not isinstance(options, dict): - options = {"type": options} - prop = ModelProperty( - cls, - name, - options.get("type"), - options.get("doc", f"{name}"), - options.get("default", None), - options.get("allow_none", False), - ) - setattr(cls, name, prop) - if name not in getattr(cls, "_glotaran_properties"): - getattr(cls, "_glotaran_properties").append(name) + Parameters + ---------- + item: type[Item] + The item type. + attr_type: type + The attribute type. + + Yields + ------ + Attribute + The attributes. + """ + for attr in fields(item): + _, item_type = strip_type_and_structure_from_attribute(attr) + with contextlib.suppress(TypeError): + # issubclass does for some reason not work with e.g. tuple as item_type + # and Parameter as attr_type + if isclass(item_type) and issubclass(item_type, attr_type): + yield attr - validators = _get_validators(cls) - setattr(cls, "_glotaran_validators", validators) - init = _init_factory(cls) - setattr(cls, "__init__", init) +def model_attributes( + item: type[Item], with_alias: bool = True +) -> Generator[Attribute, None, None]: + """Get model attributes from an item type. - from_dict = _from_dict_factory(cls) - setattr(cls, "from_dict", from_dict) + Parameters + ---------- + item: type[Item] + The item type. + with_alias: bool + Whether to return aliased attributes. + + Yields + ------ + Attribute + The model attributes. 
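+
+    For illustration: on ``DatasetModel`` this yields the ``megacomplex`` and
+    ``global_megacomplex`` attributes, since their values reference model items.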
+    """
+    for attr in iterate_attributes_of_type(item, ModelItem):
+        if with_alias or META_ALIAS not in attr.metadata:
+            yield attr
 
-    validators = _get_validators(cls)
-    setattr(cls, "_glotaran_validators", validators)
 
-    init = _init_factory(cls)
-    setattr(cls, "__init__", init)
+def parameter_attributes(item: type[Item]) -> Generator[Attribute, None, None]:
+    """Get parameter attributes from an item type.
 
-    from_dict = _from_dict_factory(cls)
-    setattr(cls, "from_dict", from_dict)
+    Parameters
+    ----------
+    item: type[Item]
+        The item type.
 
-    validate = _validation_factory(cls)
-    setattr(cls, "validate", validate)
+    Yields
+    ------
+    Attribute
+        The parameter attributes.
+    """
+    yield from iterate_attributes_of_type(item, Parameter)
 
-    as_dict = _as_dict_factory(cls)
-    setattr(cls, "as_dict", as_dict)
 
-    get_state = _get_state_factory(cls)
-    setattr(cls, "__getstate__", get_state)
+def iterate_names_and_labels(
+    item: Item, attributes: Generator[Attribute, None, None]
+) -> Generator[tuple[str, str], None, None]:
+    """Get attribute names and labels.
 
-    set_state = _set_state_factory(cls)
-    setattr(cls, "__setstate__", set_state)
+    Parameters
+    ----------
+    item: Item
+        The item.
+    attributes: Generator[Attribute, None, None]
+        The attributes.
+
+    Yields
+    ------
+    tuple[str, str]
+        The name and the label.
+    """
+    for attr in attributes:
+        structure, _ = strip_type_and_structure_from_attribute(attr)
+        value = getattr(item, attr.name)
+        name: str = attr.metadata.get(META_ALIAS, attr.name)
 
-    fill = _fill_factory(cls)
-    setattr(cls, "fill", fill)
+        if not value:
+            continue
 
-    markdown = _markdown_factory(cls)
-    setattr(cls, "markdown", markdown)
+        if structure is dict:
+            for v in value.values():
+                yield name, v if isinstance(v, str) else v.label
+        elif structure is list:
+            for v in value:
+                yield name, v if isinstance(v, str) else v.label
+        else:
+            yield name, value if isinstance(value, str) else value.label
 
-    get_parameter_labels = _get_parameter_labels_factory(cls)
-    setattr(cls, "get_parameter_labels", get_parameter_labels)
 
-    return cls
+def iterate_model_item_names_and_labels(item: Item) -> Generator[tuple[str, str], None, None]:
+    """Get model item names and labels.
 
-    return decorator
+    Parameters
+    ----------
+    item: Item
+        The item.
 
+    Yields
+    ------
+    tuple[str, str]
+        The name and the label.
+    """
+    yield from iterate_names_and_labels(item, model_attributes(item.__class__))
 
-def model_item_typed(
-    *,
-    types: dict[str, Any],
-    has_label: bool = True,
-    default_type: str = None,
-):
-    """The model_item_typed decorator adds attributes to the class to enable
-    the glotaran model parser to infer the correct class for an item when there
-    are multiple variants.
+def iterate_parameter_names_and_labels(item: Item) -> Generator[tuple[str, str], None, None]:
+    """Get parameter item names and labels.
 
     Parameters
     ----------
-    types :
-        A dictionary of types and options.
-    has_label:
-        If `False` no label property will be added.
-    """
-
-    def decorator(cls):
-
-        setattr(cls, "_glotaran_model_item", True)
-        setattr(cls, "_glotaran_model_item_typed", True)
-        setattr(cls, "_glotaran_model_item_types", types)
-        setattr(cls, "_glotaran_model_item_default_type", default_type)
+    item: Item
+        The item.
 
-        get_default_type = _get_default_type_factory(cls)
-        setattr(cls, "get_default_type", get_default_type)
+    Yields
+    ------
+    tuple[str, str]
+        The name and the label.
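+
+    For illustration: a dataset model whose ``scale`` is set to the
+    (hypothetical) label ``"scale.1"`` would yield the pair
+    ``("scale", "scale.1")``.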
+ """ + yield from iterate_names_and_labels(item, parameter_attributes(item.__class__)) - return decorator +def fill_item_attributes( + item: Item, + iterator: Iterator[Attribute], + fill_function: Callable[[str, str], Parameter | ModelItem], +): + """Fill item attributes. -def model_item_validator(need_parameter: bool): - """The model_item_validator marks a method of a model item as validation function""" + Parameters + ---------- + item: Item + The item. + iterator: Iterator[Attribute] + An iterator over attributes. + fill_function: Callable[[str, str], Parameter | ModelItem] + The function to fill the values. + """ + for attr in iterator: + value = getattr(item, attr.name) + if not value: + continue + + structure, _ = strip_type_and_structure_from_attribute(attr) + name = attr.metadata.get(META_ALIAS, attr.name) + if structure is dict: + value = { + k: fill_function(name, v) if isinstance(v, str) else fill_function(name, v.label) + for k, v in value.items() + } + elif structure is list: + value = [ + fill_function(name, v) if isinstance(v, str) else fill_function(name, v.label) + for v in value + ] + else: + value = ( + fill_function(name, value) + if isinstance(value, str) + else fill_function(name, value.label) + ) - def decorator(method: Validator | ValidatorParameter): - setattr(method, "_glotaran_validator", need_parameter) - return method + setattr(item, attr.name, value) - return decorator +def fill_item(item: ItemT, model: Model, parameters: Parameters) -> ItemT: + """Fill an item. -def _get_validators(cls): - return { - method: getattr(getattr(cls, method), "_glotaran_validator") - for method in dir(cls) - if hasattr(getattr(cls, method), "_glotaran_validator") - } + Parameters + ---------- + item: ItemT + The item. + model: Model + The model. + parameters: Parameters + The parameters. + + Returns + ------- + ItemT + The filled item. + """ + item = evolve(item) + fill_item_model_attributes(item, model, parameters) + fill_item_parameter_attributes(item, parameters) + return item -def _get_default_type_factory(cls): - @classmethod - @wrap_func_as_method(cls) - def get_default_type(cls) -> str: - return getattr(cls, "_glotaran_model_item_default_type") +def fill_item_model_attributes(item: Item, model: Model, parameters: Parameters): + """Fill item model attributes. - return get_default_type + Parameters + ---------- + item: Item + The item. + model: Model + The model. + parameters: Parameters + The parameters. + """ + fill_item_attributes( + item, + model_attributes(item.__class__), + lambda name, label: fill_item(getattr(model, name)[label], model, parameters), + ) -def _add_type_factory(cls): - @classmethod - @wrap_func_as_method(cls) - def add_type(cls, type_name: str, attribute_type: type): - getattr(cls, "_glotaran_model_item_types")[type_name] = attribute_type +def fill_item_parameter_attributes(item: Item, parameters: Parameters): + """Fill item parameter attributes. - return add_type + Parameters + ---------- + item: Item + The item. + parameters: Parameters + The parameters. + """ + fill_item_attributes( + item, parameter_attributes(item.__class__), lambda _, label: parameters.get(label) + ) -def _init_factory(cls): - @classmethod - @wrap_func_as_method(cls) - def __init__(self): - for attr in self._glotaran_properties: - setattr(self, f"_{attr}", None) +def get_item_model_issues(item: Item, model: Model) -> list[ItemIssue]: + """Get model item issues for an item. - return __init__ + Parameters + ---------- + item: Item + The item. + model: Model + The model. 
+ + Returns + ------- + list[ItemIssue] + """ + return [ + ModelItemIssue(name, label) + for name, label in iterate_model_item_names_and_labels(item) + if label not in getattr(model, name) + ] -def _from_dict_factory(cls): - @classmethod - @wrap_func_as_method(cls) - def from_dict(ncls, values: dict) -> cls: - f"""Creates an instance of {cls.__name__} from a dictionary of values. +def get_item_parameter_issues(item: Item, parameters: Parameters) -> list[ItemIssue]: + """Get model item issues for an item. - Intended only for internal use. + Parameters + ---------- + item: Item + The item. + parameters: Parameters + The parameters. + + Returns + ------- + list[ItemIssue] + """ + return [ + ParameterIssue(label) + for name, label in iterate_parameter_names_and_labels(item) + if not parameters.has(label) + ] - Parameters - ---------- - values : - A list of values. - """ - item = ncls() - for name in ncls._glotaran_properties: - if name in values: - value = values[name] - prop = getattr(item.__class__, name) - if prop.glotaran_property_type == float: - value = float(value) - elif prop.glotaran_property_type == int: - value = int(value) - setattr(item, name, value) +def get_item_validator_issues( + item: Item, model: Model, parameters: Parameters | None = None +) -> list[ItemIssue]: + """Get validator issues for an item. - elif not getattr(ncls, name).glotaran_allow_none and getattr(item, name) is None: - raise ValueError(f"Missing Property '{name}' For Item '{ncls.__name__}'") - return item + Parameters + ---------- + item: Item + The item. + model: Model + The model. + parameters: Parameters | None + The parameters. + + Returns + ------- + list[ItemIssue] + """ + issues = [] + for name, validator in [ + (attr.name, attr.metadata[META_VALIDATOR]) + for attr in fields(item.__class__) + if META_VALIDATOR in attr.metadata + ]: + issues += validator(getattr(item, name), item, model, parameters) - return from_dict + return issues -def _validation_factory(cls): - @wrap_func_as_method(cls) - def validate(self, model: Model, parameters: ParameterGroup | None = None) -> list[str]: - f"""Creates a list of parameters needed by this instance of {cls.__name__} not present in a - set of parameters. +def get_item_issues( + *, item: Item, model: Model, parameters: Parameters | None = None +) -> list[ItemIssue]: + """Get issues for an item. - Parameters - ---------- - model : - The model to validate. - parameter : - The parameter to validate. - missing : - A list the missing will be appended to. - """ - problems = [] - for name in self._glotaran_properties: - prop = getattr(self.__class__, name) - value = getattr(self, name) - problems += prop.glotaran_validate(value, model, parameters) - for validator, need_parameter in self._glotaran_validators.items(): - if need_parameter: - if parameters is not None: - problems += getattr(self, validator)(model, parameters) - else: - problems += getattr(self, validator)(model) + Parameters + ---------- + item: Item + The item. + model: Model + The model. + parameters: Parameters | None + The parameters. + + Returns + ------- + list[ItemIssue] + """ + issues = get_item_model_issues(item, model) + issues += get_item_validator_issues(item, model, parameters) + if parameters is not None: + issues += get_item_parameter_issues(item, parameters) + return issues - return problems - return validate +def strip_type_and_structure_from_attribute(attr: Attribute) -> tuple[None | list | dict, type]: + """Strip the type and the structure from an attribute. 
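+
+    For illustration: an attribute annotated ``list[Parameter | str]`` strips
+    to the structure ``list`` and the type ``Parameter``.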
+ Parameters + ---------- + attr: Attribute + The attribute. -def _as_dict_factory(cls): - @wrap_func_as_method(cls) - def as_dict(self) -> dict: - return { - name: getattr(self.__class__, name).glotaran_replace_parameter_with_labels( - getattr(self, name) - ) - for name in self._glotaran_properties - if name != "label" and getattr(self, name) is not None - } + Returns + ------- + tuple[None | list | dict, type]: + The structure and the type. + """ + definition = attr.type + definition = strip_option_type(definition) + structure, definition = strip_structure_type(definition) + definition = strip_option_type(definition, strip_type=str) + return structure, definition - return as_dict +def strip_option_type(definition: type, strip_type: type = NoneType) -> type: + """Strip the type if the definition is an option. -def _fill_factory(cls): - @wrap_func_as_method(cls) - def fill(self, model: Model, parameters: ParameterGroup) -> cls: - f"""Returns a copy of the {cls.__name__} instance with all members which are Parameters are - replaced by the value of the corresponding parameter in the parameter group. + Parameters + ---------- + definition: type + The definition. + strip_type: type + The type which should be removed from the option. + + Returns + ------- + type + """ + args = list(get_args(definition)) + if get_origin(definition) in [Union, UnionType] and strip_type in args: + args.remove(strip_type) + definition = args[0] + return definition - Parameters - ---------- - model : - A glotaran model. - parameter : ParameterGroup - The parameter group to fill from. - """ - item = copy.deepcopy(self) - for name in self._glotaran_properties: - prop = getattr(self.__class__, name) - value = getattr(self, name) - value = prop.glotaran_fill(value, model, parameters) - setattr(item, name, value) - return item - - return fill - - -def _get_state_factory(cls): - @wrap_func_as_method(cls) - def get_state(self) -> cls: - return tuple(getattr(self, name) for name in self._glotaran_properties) - - return get_state - - -def _set_state_factory(cls): - @wrap_func_as_method(cls) - def set_state(self, state) -> cls: - for i, name in enumerate(self._glotaran_properties): - setattr(self, name, state[i]) - - return set_state - - -def _markdown_factory(cls): - @wrap_func_as_method(cls, name="markdown") - def mprint_item( - self, all_parameters: ParameterGroup = None, initial_parameters: ParameterGroup = None - ) -> MarkdownStr: - f"""Returns a string with the {cls.__name__} formatted in markdown.""" - - md = "\n" - if self._glotaran_has_label: - md = f"**{self.label}**" - if hasattr(self, "type"): - md += f" ({self.type})" - md += ":\n" - - elif hasattr(self, "type"): - md = f"**{self.type}**:\n" - - for name in self._glotaran_properties: - prop = getattr(self.__class__, name) - value = getattr(self, name) - if value is None: - continue - property_md = indent( - f"* *{name.replace('_', ' ').title()}*: " - f"{prop.glotaran_value_as_markdown(value,all_parameters, initial_parameters)}\n", - " ", - ) - md += property_md +def strip_structure_type(definition: type) -> tuple[None | list | dict, type]: + """Strip the structure from a definition. - return MarkdownStr(md) + Parameters + ---------- + definition: type + The definition. - return mprint_item + Returns + ------- + tuple[None | list | dict, type]: + The structure and the type. 
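+
+    Examples
+    --------
+    Expected behavior (illustrative)::
+
+        strip_structure_type(list[str])             # -> (list, str)
+        strip_structure_type(dict[str, Parameter])  # -> (dict, Parameter)
+        strip_structure_type(int)                   # -> (None, int)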
+ """ + structure = get_origin(definition) + if structure is list: + definition = get_args(definition)[0] + elif structure is dict: + definition = get_args(definition)[1] + else: + structure = None + return structure, definition -def _get_parameter_labels_factory(cls): - @wrap_func_as_method(cls, name="get_parameter_labels") - def get_parameter_labels(self) -> list[str]: - parameter_labels = [] +def item(cls: type[ItemT]) -> type[ItemT]: + """Create an item from a class. - for name in self._glotaran_properties: - prop = getattr(self.__class__, name) - value = getattr(self, name) - parameter_labels += prop.glotaran_get_parameter_labels(value) + Parameters + ---------- + cls: type[ItemT] + The class. - return parameter_labels + Returns + ------- + type[ItemT] + """ + parent = getmro(cls)[1] + cls = define(kw_only=True, slots=False)(cls) + if parent in (TypedItem, ModelItemTyped): + assert issubclass(cls, TypedItem) + cls.__item_types__ = {} + elif issubclass(cls, TypedItem): + cls._register_item_class() + resolve_types(cls) + return cls + + +def attribute( + *, + alias: str | None = None, + default: Any = NOTHING, + factory: Callable[[], Any] = None, + validator: Callable[[Any, Item, Model, Parameters | None], list[ItemIssue]] | None = None, +) -> Attribute: + """Create an attribute for an item. - return get_parameter_labels + Parameters + ---------- + alias: str | None + The alias of the attribute (only useful for model items). + default: Any + The default value of the attribute. + factory: Callable + A factory function for the attribute. + validator: Callable[[Any, Item, Model, Parameters | None], list[ItemIssue]] | None + A validator function for the attribute. + + Returns + ------- + Attribute + """ + metadata: dict[str, Any] = {} + if alias is not None: + metadata[META_ALIAS] = alias + if validator is not None: + metadata[META_VALIDATOR] = validator + return field(default=default, factory=factory, metadata=metadata) diff --git a/glotaran/model/megacomplex.py b/glotaran/model/megacomplex.py index 10362d873..b84e96d16 100644 --- a/glotaran/model/megacomplex.py +++ b/glotaran/model/megacomplex.py @@ -1,113 +1,86 @@ +"""This module contains the megacomplex.""" from __future__ import annotations from typing import TYPE_CHECKING -from typing import Dict -from typing import List +from typing import Callable +from typing import ClassVar import numpy as np import xarray as xr +from attrs import NOTHING +from attrs import fields -from glotaran.model.item import model_item -from glotaran.model.item import model_item_typed -from glotaran.model.util import get_subtype -from glotaran.model.util import is_mapping_type -from glotaran.model.util import is_sequence_type +from glotaran.model.item import ModelItemTyped +from glotaran.model.item import item from glotaran.plugin_system.megacomplex_registration import register_megacomplex if TYPE_CHECKING: - from typing import Any from glotaran.model import DatasetModel -def create_model_megacomplex_type( - megacomplex_types: dict[str, Megacomplex], default_type: str = None -) -> type: - @model_item_typed(types=megacomplex_types, default_type=default_type) - class ModelMegacomplex: - """This class holds all Megacomplex types defined by a model.""" - - return ModelMegacomplex - - def megacomplex( *, - dimension: str | None = None, - model_items: dict[str, dict[str, Any]] = None, - properties: Any | dict[str, dict[str, Any]] = None, - dataset_model_items: dict[str, dict[str, Any]] = None, - dataset_properties: Any | dict[str, dict[str, Any]] = None, - unique: 
bool = False, + dataset_model_type: type[DatasetModel] | None = None, exclusive: bool = False, - register_as: str | None = None, -): - """The `@megacomplex` decorator is intended to be used on subclasses of - :class:`glotaran.model.Megacomplex`. It registers the megacomplex model - and makes it available in analysis models. + unique: bool = False, +) -> Callable: + """Create a megacomplex from a class. + + Parameters + ---------- + dataset_model_type: type + The dataset model type. + exclusive: bool + Whether the megacomplex is exclusive. + unique: bool + Whether the megacomplex is unique. + + Returns + ------- + Callable """ - properties = properties if properties is not None else {} - properties["dimension"] = {"type": str} - if dimension is not None: - properties["dimension"]["default"] = dimension - - if model_items is None: - model_items = {} - else: - model_items, properties = _add_model_items_to_properties(model_items, properties) - - dataset_properties = dataset_properties if dataset_properties is not None else {} - if dataset_model_items is None: - dataset_model_items = {} - else: - dataset_model_items, dataset_properties = _add_model_items_to_properties( - dataset_model_items, dataset_properties - ) def decorator(cls): - setattr(cls, "_glotaran_megacomplex_model_items", model_items) - setattr(cls, "_glotaran_megacomplex_dataset_model_items", dataset_model_items) - setattr(cls, "_glotaran_megacomplex_dataset_properties", dataset_properties) - setattr(cls, "_glotaran_megacomplex_unique", unique) - setattr(cls, "_glotaran_megacomplex_exclusive", exclusive) - - megacomplex_type = model_item(properties=properties, has_type=True)(cls) + megacomplex_type = item(cls) + megacomplex_type.__dataset_model_type__ = dataset_model_type + megacomplex_type.__is_exclusive__ = exclusive + megacomplex_type.__is_unique__ = unique - if register_as is not None: - megacomplex_type.name = register_as - register_megacomplex(register_as, megacomplex_type) + megacomplex_type_str = fields(cls).type.default + if megacomplex_type_str is not NOTHING: + register_megacomplex(megacomplex_type_str, megacomplex_type) return megacomplex_type return decorator -def _add_model_items_to_properties(model_items: dict, properties: dict) -> tuple[dict, dict]: - for name, item in model_items.items(): - item_type = item["type"] if isinstance(item, dict) else item - property_type = str - - if is_sequence_type(item_type): - property_type = List[str] - item_type = get_subtype(item_type) - elif is_mapping_type(item_type): - property_type = Dict[str, str] - item_type = get_subtype(item_type) - - property_dict = item.copy() if isinstance(item, dict) else {} - property_dict["type"] = property_type - properties[name] = property_dict - model_items[name] = item_type - return model_items, properties - - -class Megacomplex: +@item +class Megacomplex(ModelItemTyped): """A base class for megacomplex models. Subclasses must overwrite :method:`glotaran.model.Megacomplex.calculate_matrix` and :method:`glotaran.model.Megacomplex.index_dependent`. """ + dimension: str | None = None + + __dataset_model_type__: ClassVar[type | None] = None + __is_exclusive__: ClassVar[bool] + __is_unique__: ClassVar[bool] + + @classmethod + def get_dataset_model_type(cls) -> type | None: + """Get the dataset model type. 
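+
+        A megacomplex can ship a dataset model extension by passing
+        ``dataset_model_type`` to the ``megacomplex`` decorator; this accessor
+        returns that class, or ``None`` if none was given.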
+
+        Returns
+        -------
+        type | None
+        """
+        return cls.__dataset_model_type__
+
     def calculate_matrix(
         self,
         dataset_model: DatasetModel,
@@ -115,10 +88,47 @@ def calculate_matrix(
         global_axis: np.typing.ArrayLike,
         model_axis: np.typing.ArrayLike,
         **kwargs,
-    ) -> xr.DataArray:
+    ) -> tuple[list[str], np.typing.ArrayLike]:
+        """Calculate the megacomplex matrix.
+
+        Parameters
+        ----------
+        dataset_model: DatasetModel
+            The dataset model.
+        global_index: int | None
+            The global index.
+        global_axis: np.typing.ArrayLike
+            The global axis.
+        model_axis: np.typing.ArrayLike
+            The model axis.
+        **kwargs
+            Additional arguments.
+
+        Returns
+        -------
+        tuple[list[str], np.typing.ArrayLike]
+            The clp labels and the matrix.
+
+        .. # noqa: DAR202
+        .. # noqa: DAR401
+        """
         raise NotImplementedError

     def index_dependent(self, dataset_model: DatasetModel) -> bool:
+        """Check if the megacomplex is index dependent.
+
+        Parameters
+        ----------
+        dataset_model: DatasetModel
+            The dataset model.
+
+        Returns
+        -------
+        bool
+
+        .. # noqa: DAR202
+        .. # noqa: DAR401
+        """
         raise NotImplementedError

     def finalize_data(
@@ -128,24 +138,51 @@ def finalize_data(
         is_full_model: bool = False,
         as_global: bool = False,
     ):
+        """Finalize a dataset.
+
+        Parameters
+        ----------
+        dataset_model: DatasetModel
+            The dataset model.
+        dataset: xr.Dataset
+            The dataset.
+        is_full_model: bool
+            Whether the model is a full model.
+        as_global: bool
+            Whether the megacomplex is calculated as a global megacomplex.
+
+
+        .. # noqa: DAR101
+        .. # noqa: DAR401
+        """
         raise NotImplementedError

-    @classmethod
-    def glotaran_model_items(cls) -> str:
-        return cls._glotaran_megacomplex_model_items

-    @classmethod
-    def glotaran_dataset_model_items(cls) -> str:
-        return cls._glotaran_megacomplex_dataset_model_items
+def is_exclusive(cls: type[Megacomplex]) -> bool:
+    """Check if the megacomplex is exclusive.

-    @classmethod
-    def glotaran_dataset_properties(cls) -> str:
-        return cls._glotaran_megacomplex_dataset_properties
+
+    Parameters
+    ----------
+    cls: type[Megacomplex]
+        The megacomplex type.
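The new `calculate_matrix` contract returns the clp labels together with the matrix instead of an `xr.DataArray`. A sketch of a conforming subclass, assuming the API as introduced in this diff (class and type name are illustrative):

```python
import numpy as np

from glotaran.model.megacomplex import Megacomplex, megacomplex


@megacomplex()
class ConstantSketchMegacomplex(Megacomplex):
    """Illustrative megacomplex with a single, constant clp column."""

    type: str = "constant-sketch"

    def calculate_matrix(self, dataset_model, global_index, global_axis, model_axis, **kwargs):
        # One clp label and a matrix of shape (model axis size, number of clps).
        return ["constant"], np.ones((len(model_axis), 1))

    def index_dependent(self, dataset_model) -> bool:
        return False
```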
+ + Returns + ------- + bool + """ + return cls.__is_unique__ diff --git a/glotaran/model/model.py b/glotaran/model/model.py index ae7cc74e1..d521a1190 100644 --- a/glotaran/model/model.py +++ b/glotaran/model/model.py @@ -1,442 +1,472 @@ -"""A base class for global analysis models.""" +"""This module contains the model.""" from __future__ import annotations -import copy -from dataclasses import asdict from typing import Any -from typing import List -from warnings import warn +from typing import Callable +from typing import ClassVar +from typing import Generator +from typing import Mapping +from uuid import uuid4 + +from attr import asdict +from attr import fields +from attr import ib +from attrs import Attribute +from attrs import define +from attrs import make_class +from attrs import resolve_types -from glotaran.deprecation import raise_deprecation_error from glotaran.io import load_model -from glotaran.model.clp_penalties import EqualAreaPenalty -from glotaran.model.constraint import Constraint +from glotaran.model.clp_constraint import ClpConstraint +from glotaran.model.clp_penalties import ClpPenalty +from glotaran.model.clp_relation import ClpRelation from glotaran.model.dataset_group import DatasetGroup from glotaran.model.dataset_group import DatasetGroupModel -from glotaran.model.dataset_model import create_dataset_model_type +from glotaran.model.dataset_model import DatasetModel +from glotaran.model.item import Item +from glotaran.model.item import ItemIssue +from glotaran.model.item import ModelItem +from glotaran.model.item import TypedItem +from glotaran.model.item import get_item_issues +from glotaran.model.item import item_to_markdown +from glotaran.model.item import iterate_parameter_names_and_labels +from glotaran.model.item import model_attributes +from glotaran.model.item import strip_type_and_structure_from_attribute from glotaran.model.megacomplex import Megacomplex -from glotaran.model.megacomplex import create_model_megacomplex_type -from glotaran.model.relation import Relation -from glotaran.model.util import ModelError from glotaran.model.weight import Weight from glotaran.parameter import Parameter -from glotaran.parameter import ParameterGroup -from glotaran.plugin_system.megacomplex_registration import get_megacomplex +from glotaran.parameter import Parameters from glotaran.utils.ipython import MarkdownStr -default_model_items = { - "clp_area_penalties": EqualAreaPenalty, - "clp_constraints": Constraint, - "clp_relations": Relation, - "weights": Weight, -} +DEFAULT_DATASET_GROUP = "default" +META_ITEMS = "__glotaran_items__" +META = {META_ITEMS: True} -default_dataset_properties = { - "group": {"type": str, "default": "default"}, - "force_index_dependent": {"type": bool, "allow_none": True}, - "megacomplex": List[str], - "megacomplex_scale": {"type": List[Parameter], "allow_none": True}, - "global_megacomplex": {"type": List[str], "allow_none": True}, - "global_megacomplex_scale": {"type": List[Parameter], "default": None, "allow_none": True}, - "scale": {"type": Parameter, "default": None, "allow_none": True}, -} -root_parameter_error = ModelError( - "The root parameter group cannot contain both groups and parameters." 
-) +class ModelError(Exception): + """Raised when a model contains errors.""" - -class Model: - """A base class for global analysis models.""" - - loader = load_model - - def __init__( - self, - *, - megacomplex_types: dict[str, type[Megacomplex]], - default_megacomplex_type: str | None = None, - dataset_group_models: dict[str, DatasetGroupModel] = None, - ): - self._megacomplex_types = megacomplex_types - self._default_megacomplex_type = default_megacomplex_type or next(iter(megacomplex_types)) - - self._dataset_group_models = dataset_group_models or {"default": DatasetGroupModel()} - if "default" not in self._dataset_group_models: - self._dataset_group_models["default"] = DatasetGroupModel() - - self._model_items = {} - self._dataset_properties = {} - self._add_default_items_and_properties() - self._add_megacomplexe_types() - self._add_dataset_type() - self.source_path = "model.yml" - - @classmethod - def from_dict( - cls, - model_dict: dict[str, Any], - *, - megacomplex_types: dict[str, type[Megacomplex]] | None = None, - default_megacomplex_type: str | None = None, - ) -> Model: - """Creates a model from a dictionary. + def __init__(self, error: str): + """Create a model error. Parameters ---------- - model_dict: dict[str, Any] - Dictionary containing the model. - megacomplex_types: dict[str, type[Megacomplex]] | None - Overwrite 'megacomplex_types' in ``model_dict`` for testing. - default_megacomplex_type: str | None - Overwrite 'default_megacomplex' in ``model_dict`` for testing. + error: str + The error string. """ - model_dict = copy.deepcopy(model_dict) - if default_megacomplex_type is None: - default_megacomplex_type = model_dict.get("default_megacomplex") - - if megacomplex_types is None: - megacomplex_types = { - m["type"]: get_megacomplex(m["type"]) - for m in model_dict["megacomplex"].values() - if "type" in m - } - if ( - default_megacomplex_type is not None - and default_megacomplex_type not in megacomplex_types - ): - megacomplex_types[default_megacomplex_type] = get_megacomplex(default_megacomplex_type) - if "default_megacomplex" in model_dict: - model_dict.pop("default_megacomplex", None) - - dataset_group_models = model_dict.pop("dataset_groups", None) - if dataset_group_models is not None: - dataset_group_models = { - label: DatasetGroupModel(**group) for label, group in dataset_group_models.items() - } - - model = cls( - megacomplex_types=megacomplex_types, - default_megacomplex_type=default_megacomplex_type, - dataset_group_models=dataset_group_models, + super().__init__(f"ModelError: {error}") + + +def _load_item_from_dict( + item_type: type[Item], value: Item | Mapping, extra: dict[str, Any] | None = None +) -> Item: + """Load an item from a dictionary. + + Parameters + ---------- + item_type: type[Item] + The item type. + value: Item | dict + The value to load from. + extra: dict[str, Any] | None + Extra arguments for the item. + + Returns + ------- + Item + + Raises + ------ + ModelError + Raised if a modelitem is missing. + """ + if not isinstance(value, Item): + if extra: + value = value | extra + if issubclass(item_type, TypedItem): + try: + item_type = item_type.get_item_type_class(value["type"]) + except KeyError: + raise ModelError(f"Missing 'type' for item {item_type}") + return item_type(**(value)) + return value + + +def _load_model_items_from_dict( + item_type: type[Item], item_dict: Mapping[str, ModelItem | dict] +) -> dict[str, ModelItem]: + """Load a model items from a dictionary. + + Parameters + ---------- + item_type: type[Item] + The item type. 
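The loader above is the single entry point that turns plain dictionaries into items, dispatching on the `"type"` key for typed items. Stripped of attrs specifics and error handling, the dispatch amounts to this simplified mirror (`get_item_type_class` is the lookup this diff adds on `TypedItem`; the function name here is illustrative):

```python
def load_item_sketch(item_type, value, extra=None):
    """Simplified mirror of _load_item_from_dict for illustration."""
    if isinstance(value, item_type):
        return value  # already constructed, pass through unchanged
    if extra:
        value = value | extra  # e.g. inject {"label": ...} for model items
    # typed items select their concrete class via the "type" key
    concrete = item_type.get_item_type_class(value["type"])
    return concrete(**value)
```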
+    item_dict: dict[str, ModelItem | dict]
+        The item dictionary.
+
+    Returns
+    -------
+    dict[str, ModelItem]
+    """
+    return {
+        label: _load_item_from_dict(item_type, value, extra={"label": label})  # type:ignore[misc]
+        for label, value in item_dict.items()
+    }
+
+
+def _load_global_items_from_dict(
+    item_type: type[Item], item_list: list[Item | dict]
+) -> list[Item]:
+    """Load global items from a list of items or dictionaries.
+
+    Parameters
+    ----------
+    item_type: type[Item]
+        The item type.
+    item_list: list[Item | dict]
+        The list of item dicts.
+
+    Returns
+    -------
+    list[Item]
+    """
+    return [_load_item_from_dict(item_type, value) for value in item_list]
+
+
+def _load_dataset_groups(
+    dataset_groups: dict[str, DatasetGroupModel | Any]
+) -> dict[str, DatasetGroupModel]:
+    """Load the dataset groups and add the default group if not present.
+
+    Parameters
+    ----------
+    dataset_groups: dict[str, DatasetGroupModel | Any]
+        The dataset groups.
+
+    Returns
+    -------
+    dict[str, DatasetGroupModel]
+    """
+    dataset_group_items = _load_model_items_from_dict(DatasetGroupModel, dataset_groups)
+    if DEFAULT_DATASET_GROUP not in dataset_group_items:
+        dataset_group_items[DEFAULT_DATASET_GROUP] = DatasetGroupModel(
+            label=DEFAULT_DATASET_GROUP  # type:ignore[call-arg]
+        )
+    return dataset_group_items  # type:ignore[return-value]
+
+
+def _global_item_attribute(item_type: type[Item]) -> Attribute:
+    """Create a global item attribute.
+
+    Parameters
+    ----------
+    item_type: type[Item]
+        The item type.
+
+    Returns
+    -------
+    Attribute
+    """
+    return ib(
+        factory=list,
+        converter=lambda value: _load_global_items_from_dict(item_type, value),
+        metadata=META,
+    )
+
+
+def _model_item_attribute(item_type: type[ModelItem]) -> Attribute:
+    """Create a model item attribute.
+
+    Parameters
+    ----------
+    item_type: type[ModelItem]
+        The item type.
+
+    Returns
+    -------
+    Attribute
+    """
+    return ib(
+        type=dict[str, item_type],  # type:ignore[valid-type]
+        factory=dict,
+        converter=lambda value: _load_model_items_from_dict(item_type, value),
+        metadata=META,
+    )
+
+
+def _create_attributes_for_item(item_type: type[Item]) -> dict[str, Attribute]:
+    """Create attributes for an item.
+
+    Parameters
+    ----------
+    item_type: type[Item]
+        The item type.
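These attribute factories lean on attrs converters: assigning raw dictionaries to a model field converts them into items at construction time. The pattern in isolation, with plain attrs and a toy converter (the `Holder` class is made up for illustration):

```python
from attrs import define, field


@define
class Holder:
    """Toy stand-in for a model field built by _model_item_attribute."""

    items: dict = field(
        factory=dict,
        # converters run on construction, so raw input is normalized
        converter=lambda value: {key: str(val) for key, val in value.items()},
    )


holder = Holder(items={"a": 1})
assert holder.items == {"a": "1"}
```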
+ + Returns + ------- + dict[str, Attribute] + """ + attributes = {} + for model_item in model_attributes(item_type, with_alias=False): + _, model_item_type = strip_type_and_structure_from_attribute(model_item) + attributes[model_item.name] = _model_item_attribute(model_item_type) + return attributes + + +@define(kw_only=True) +class Model: + """A model for global target analysis.""" - # iterate over items - for item_name, items in list(model_dict.items()): - - if item_name not in model.model_items: - warn(f"Unknown model item type '{item_name}'.") - continue + loader: ClassVar[Callable] = load_model - if isinstance(getattr(model, item_name), list): - model._add_list_items(item_name, items) - else: - model._add_dict_items(item_name, items) - - return model - - def _add_dict_items(self, item_name: str, items: dict): - - for label, item in items.items(): - item_cls = self.model_items[item_name] - if hasattr(item_cls, "_glotaran_model_item_typed"): - if "type" not in item and item_cls.get_default_type() is None: - raise ValueError(f"Missing type for attribute '{item_name}'") - item_type = item.get("type", item_cls.get_default_type()) - - types = item_cls._glotaran_model_item_types - if item_type not in types: - raise ValueError(f"Unknown type '{item_type}' for attribute '{item_name}'") - item_cls = types[item_type] - item["label"] = label - item = item_cls.from_dict(item) - getattr(self, item_name)[label] = item - - def _add_list_items(self, item_name: str, items: list): - - for item in items: - item_cls = self.model_items[item_name] - if hasattr(item_cls, "_glotaran_model_item_typed"): - if "type" not in item: - raise ValueError(f"Missing type for attribute '{item_name}'") - item_type = item["type"] - - if item_type not in item_cls._glotaran_model_item_types: - raise ValueError(f"Unknown type '{item_type}' for attribute '{item_name}'") - item_cls = item_cls._glotaran_model_item_types[item_type] - item = item_cls.from_dict(item) - getattr(self, item_name).append(item) - - def _add_megacomplexe_types(self): - - for megacomplex_name, megacomplex_type in self._megacomplex_types.items(): - if not issubclass(megacomplex_type, Megacomplex): - raise TypeError( - f"Megacomplex type {megacomplex_name}({megacomplex_type}) " - "is not a subclass of Megacomplex" - ) - self._add_megacomplex_type(megacomplex_type) + source_path: str | None = ib(default=None, init=False, repr=False) + clp_penalties: list[ClpPenalty] = _global_item_attribute(ClpPenalty) + clp_constraints: list[ClpConstraint] = _global_item_attribute(ClpConstraint) + clp_relations: list[ClpRelation] = _global_item_attribute(ClpRelation) - model_megacomplex_type = create_model_megacomplex_type( - self._megacomplex_types, self.default_megacomplex - ) - self._add_model_item("megacomplex", model_megacomplex_type) + dataset_groups: dict[str, DatasetGroupModel] = ib( + factory=dict, converter=_load_dataset_groups, metadata=META + ) - def _add_megacomplex_type(self, megacomplex_type: type[Megacomplex]): + dataset: dict[str, DatasetModel] - for item_name, item in megacomplex_type.glotaran_model_items().items(): - self._add_model_item(item_name, item) + megacomplex: dict[str, Megacomplex] = ib( + factory=dict, + converter=lambda value: _load_model_items_from_dict(Megacomplex, value), + metadata=META, + ) - for item_name, item in megacomplex_type.glotaran_dataset_model_items().items(): - self._add_model_item(item_name, item) + weights: list[Weight] = _global_item_attribute(Weight) - for property_name, prop in 
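Fields created by these factories carry the `META_ITEMS` marker in their metadata, which is what `iterate_items` later filters on. The filtering pattern standalone (`SketchModel` is illustrative):

```python
from attrs import define, field, fields

MARKER = "__glotaran_items__"  # mirrors META_ITEMS above


@define
class SketchModel:
    plain: int = 0
    items: dict = field(factory=dict, metadata={MARKER: True})


marked = [attr.name for attr in fields(SketchModel) if MARKER in attr.metadata]
assert marked == ["items"]
```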
megacomplex_type.glotaran_dataset_properties().items(): - self._add_dataset_property(property_name, prop) + @classmethod + def create_class(cls, attributes: dict[str, Attribute]) -> type[Model]: + """Create model class. - def _add_model_item(self, item_name: str, item: type): - if item_name in self._model_items: - if self.model_items[item_name] != item: - raise ModelError( - f"Cannot add item of type {item_name}. Model item '{item_name}' " - "was already defined as a different type." - ) - return - self._model_items[item_name] = item + Parameters + ---------- + attributes: dict[str, Attribute] + The model attributes. - if getattr(item, "_glotaran_has_label"): - setattr(self, f"{item_name}", {}) - else: - setattr(self, f"{item_name}", []) - - def _add_dataset_property(self, property_name: str, dataset_property: dict[str, any]): - if property_name in self._dataset_properties: - known_type = ( - self._dataset_properties[property_name]["type"] - if isinstance(self._dataset_properties, dict) - else self._dataset_properties[property_name] - ) + Returns + ------- + type[Model] + """ + cls_name = f"GlotaranModel_{str(uuid4()).replace('-','_')}" + return make_class(cls_name, attributes, bases=(cls,)) - new_type = ( - dataset_property["type"] - if isinstance(dataset_property, dict) - else dataset_property - ) + @classmethod + def create_class_from_megacomplexes( + cls, megacomplexes: list[type[Megacomplex]] + ) -> type[Model]: + """Create model class for megacomplexes. - if known_type != new_type: - raise ModelError( - f"Cannot add dataset property of type {property_name} as it was " - "already defined as a different type." - ) - return - self._dataset_properties[property_name] = dataset_property - - def _add_default_items_and_properties(self): - for item_name, item in default_model_items.items(): - self._add_model_item(item_name, item) - - for property_name, prop in default_dataset_properties.items(): - self._add_dataset_property(property_name, prop) - - def _add_dataset_type(self): - dataset_model_type = create_dataset_model_type(self._dataset_properties) - self._add_model_item("dataset", dataset_model_type) - - @property - def model_dimension(self): - """Deprecated use ``Scheme.model_dimensions['']`` instead""" - raise_deprecation_error( - deprecated_qual_name_usage="Model.model_dimension", - new_qual_name_usage=("Scheme.model_dimensions['']"), - to_be_removed_in_version="0.7.0", - ) + Parameters + ---------- + megacomplexes: list[type[Megacomplex]] + The megacomplexes. 
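`create_class` builds a fresh attrs class at runtime with `attrs.make_class`, using a uuid-based name to avoid collisions. The mechanism in miniature, under the same slots/kw_only settings the diff uses (`SketchBase` and `extra` are illustrative):

```python
from uuid import uuid4

from attrs import define, ib, make_class


@define(kw_only=True, slots=False)
class SketchBase:
    label: str = "base"


cls_name = f"Sketch_{str(uuid4()).replace('-', '_')}"
Dynamic = make_class(cls_name, {"extra": ib(default=1)}, bases=(SketchBase,))

instance = Dynamic(label="x", extra=2)
assert (instance.label, instance.extra) == ("x", 2)
```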
- @property - def global_dimension(self): - """Deprecated use ``Scheme.global_dimensions['']`` instead""" - raise_deprecation_error( - deprecated_qual_name_usage="Model.global_dimension", - new_qual_name_usage=("Scheme.global_dimensions['']"), - to_be_removed_in_version="0.7.0", + Returns + ------- + type[Model] + """ + attributes: dict[str, Attribute] = {} + dataset_types = set() + for megacomplex in megacomplexes: + if dataset_model_type := megacomplex.get_dataset_model_type(): + dataset_types |= { + dataset_model_type, + } + attributes.update(_create_attributes_for_item(megacomplex)) + + dataset_type = ( + DatasetModel + if len(dataset_types) == 0 + else make_class( + f"GlotaranDataset_{str(uuid4()).replace('-','_')}", + [], + bases=tuple(dataset_types), + collect_by_mro=True, + ) ) + resolve_types(dataset_type) - @property - def default_megacomplex(self) -> str: - """The default megacomplex used by this model.""" - return self._default_megacomplex_type + attributes.update(_create_attributes_for_item(dataset_type)) - @property - def megacomplex_types(self) -> dict[str, type[Megacomplex]]: - """The megacomplex types used by this model.""" - return self._megacomplex_types + attributes["dataset"] = _model_item_attribute(dataset_type) - @property - def dataset_group_models(self) -> dict[str, DatasetGroupModel]: - return self._dataset_group_models + return cls.create_class(attributes) - @property - def model_items(self) -> dict[str, type[object]]: - """The model_items types used by this model.""" - return self._model_items + def as_dict(self) -> dict: + """Get the model as dictionary. - @property - def global_megacomplex(self) -> dict[str, Megacomplex]: - """Alias for `glotaran.model.megacomplex`. Needed internally.""" - return self.megacomplex + Returns + ------- + dict + """ + return asdict( + self, + recurse=True, + retain_collection_types=True, + filter=lambda attr, _: attr.name != "source_path", + ) def get_dataset_groups(self) -> dict[str, DatasetGroup]: + """Get the dataset groups. + + Returns + ------- + dict[str, DatasetGroup] + + Raises + ------ + ModelError + Raised if a dataset group is unknown. + """ groups = {} for dataset_model in self.dataset.values(): group = dataset_model.group if group not in groups: try: - group_model = self.dataset_group_models[group] + group_model = self.dataset_groups[group] except KeyError: - raise ValueError(f"Unknown dataset group '{group}'") + raise ModelError(f"Unknown dataset group '{group}'") groups[group] = DatasetGroup( - residual_function=group_model.residual_function, - link_clp=group_model.link_clp, - model=self, + residual_function=group_model.residual_function, # type:ignore[call-arg] + link_clp=group_model.link_clp, # type:ignore[call-arg] + model=self, # type:ignore[call-arg] ) groups[group].dataset_models[dataset_model.label] = dataset_model return groups - def as_dict(self) -> dict: - model_dict = { - "default_megacomplex": self.default_megacomplex, - "dataset_groups": { - label: asdict(group) for label, group in self.dataset_group_models.items() - }, + def iterate_items(self) -> Generator[tuple[str, dict[str, Item] | list[Item]], None, None]: + """Iterate items. + + Yields + ------ + tuple[str, dict[str, Item] | list[Item]] + The name of the item and the individual items of the type. + """ + for attr in fields(self.__class__): + if META_ITEMS in attr.metadata: + yield attr.name, getattr(self, attr.name) # type:ignore[misc] + + def iterate_all_items(self) -> Generator[Item, None, None]: + """Iterate the individual items. 
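When several megacomplexes contribute their own dataset-model types, `create_class_from_megacomplexes` merges them into a single class through multiple inheritance with `collect_by_mro=True`, which gathers the attrs fields from every base. The merge in isolation (`A`, `B`, and `Merged` are illustrative):

```python
from attrs import define, make_class


@define(kw_only=True, slots=False)
class A:
    a: int = 1


@define(kw_only=True, slots=False)
class B:
    b: int = 2


Merged = make_class("Merged", [], bases=(A, B), collect_by_mro=True)
merged = Merged(a=10, b=20)
assert (merged.a, merged.b) == (10, 20)
```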
+ + Yields + ------ + Item + The individual item. + """ + for _, items in self.iterate_items(): + yield from items.values() if isinstance(items, dict) else items + + def get_parameter_labels(self) -> set[str]: + """Get all parameter labels. + + Returns + ------- + set[str] + """ + return { + label + for item in self.iterate_all_items() + for _, label in iterate_parameter_names_and_labels(item) } - for item_name in self._model_items: - items = getattr(self, item_name) - if len(items) == 0: - continue - if isinstance(items, list): - model_dict[item_name] = [item.as_dict() for item in items] - else: - model_dict[item_name] = {label: item.as_dict() for label, item in items.items()} - - return model_dict - - def get_parameter_labels(self) -> list[str]: - parameter_labels = [] - for item_name in self.model_items: - items = getattr(self, item_name) - item_iterator = items if isinstance(items, list) else items.values() - for item in item_iterator: - parameter_labels += item.get_parameter_labels() - return parameter_labels - - def generate_parameters(self) -> dict | list: - parameters: dict | list = {} - for parameter in self.get_parameter_labels(): - groups = parameter.split(".") - label = groups.pop() - if len(groups) == 0: - if isinstance(parameters, dict): - if len(parameters) != 0: - raise root_parameter_error - else: - parameters = [] - parameters.append(Parameter.create_default_list(label)) - else: - if isinstance(parameters, list): - raise root_parameter_error - this_group = groups.pop() - group = parameters - for name in groups: - if name not in group: - group[name] = {} - group = group[name] - if this_group not in group: - group[this_group] = [] - group[this_group].append(Parameter.create_default_list(label)) - return parameters - - def need_index_dependent(self) -> bool: - """Returns true if e.g. clp_relations with intervals are present.""" - return any(i.interval is not None for i in self.clp_constraints + self.clp_relations) - - def problem_list(self, parameters: ParameterGroup | None = None) -> list[str]: + + def generate_parameters(self) -> Parameters: + """Generate parameters for the model. + + Returns + ------- + Parameters + The generated parameters. + + .. # noqa: D414 """ - Returns a list with all problems in the model and missing parameters if specified. + return Parameters( + {label: Parameter(label=label, value=0) for label in self.get_parameter_labels()} + ) + + def get_issues(self, *, parameters: Parameters | None = None) -> list[ItemIssue]: + """Get issues. Parameters ---------- + parameters: Parameters | None + The parameters. - parameter : - The parameter to validate. 
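`get_parameter_labels` feeds directly into `generate_parameters`: every label discovered on the items becomes a zero-valued default parameter, exactly as in the comprehension above. A standalone sketch of that final step (the labels are made up):

```python
from glotaran.parameter import Parameter, Parameters

labels = {"rates.k1", "rates.k2"}
parameters = Parameters({label: Parameter(label=label, value=0) for label in labels})
```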
+ Returns + ------- + list[ItemIssue] """ - problems = [] - - for name in self.model_items: - items = getattr(self, name) - if isinstance(items, list): - for item in items: - problems += item.validate(self, parameters=parameters) - else: - for item in items.values(): - problems += item.validate(self, parameters=parameters) - - if parameters is not None and len(parameters.missing_parameter_value_labels) != 0: - label_prefix = "\n - " - problems.append( - f"Parameter definition is missing values for the labels:" - f"{label_prefix}{label_prefix.join(parameters.missing_parameter_value_labels)}" - ) - - return problems + issues = [] + for item in self.iterate_all_items(): + issues += get_item_issues(item=item, model=self, parameters=parameters) + return issues def validate( - self, parameters: ParameterGroup = None, raise_exception: bool = False + self, parameters: Parameters | None = None, raise_exception: bool = False ) -> MarkdownStr: - """ - Returns a string listing all problems in the model and missing parameters if specified. + """Get a string listing all issues in the model and missing parameters if specified. Parameters ---------- - - parameter : - The parameter to validate. + parameters: Parameters | None + The parameters. + raise_exception: bool + Whether to raise an exception on failed validation. + + Returns + ------- + MarkdownStr + + Raises + ------ + ModelError + Raised if validation fails and raise_exception is true. """ result = "" - if problems := self.problem_list(parameters): - result = f"Your model has {len(problems)} problem{'s' if len(problems) > 1 else ''}:\n" - for p in problems: - result += f"\n * {p}" + if issues := self.get_issues(parameters=parameters): + result = f"Your model has {len(issues)} problem{'s' if len(issues) > 1 else ''}:\n" + for issue in issues: + result += f"\n * {issue.to_string()}" if raise_exception: raise ModelError(result) else: result = "Your model is valid." return MarkdownStr(result) - def valid(self, parameters: ParameterGroup = None) -> bool: - """Returns `True` if the number problems in the model is 0, else `False` + def valid(self, parameters: Parameters | None = None) -> bool: + """Check if the model is valid. Parameters ---------- + parameters: Parameters | None + The parameters. - parameter : - The parameter to validate. + Returns + ------- + bool """ - return len(self.problem_list(parameters)) == 0 + return len(self.get_issues(parameters=parameters)) == 0 def markdown( self, - parameters: ParameterGroup = None, - initial_parameters: ParameterGroup = None, + parameters: Parameters = None, + initial_parameters: Parameters = None, base_heading_level: int = 1, ) -> MarkdownStr: - """Formats the model as Markdown string. + """Format the model as Markdown string. Parameters will be included if specified. Parameters ---------- - parameter: ParameterGroup + parameters: Parameters Parameter to include. - initial_parameters: ParameterGroup + initial_parameters: Parameters Initial values for the parameters. base_heading_level: int Base heading level of the markdown sections. @@ -445,44 +475,38 @@ def markdown( - If it is 1 the string will start with '# Model'. - If it is 3 the string will start with '### Model'. 
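For reference, the report string that `validate` assembles from the collected issues follows this shape (plain strings stand in for `ItemIssue.to_string()` results here):

```python
issues = ["Missing Parameter: 'rates.k1'", "Missing Model Item: 'megacomplex'['m7']"]
result = f"Your model has {len(issues)} problem{'s' if len(issues) > 1 else ''}:\n"
for issue in issues:
    result += f"\n * {issue}"
print(result)
```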
+ + Returns + ------- + MarkdownStr """ base_heading = "#" * base_heading_level string = f"{base_heading} Model\n\n" - string += "_Megacomplex Types_: " - string += ", ".join(self._megacomplex_types) - string += "\n\n" - - string += f"{base_heading}# Dataset Groups\n\n" - for group_name, group in self.dataset_group_models.items(): - string += f"* **{group_name}**:\n" - string += f" * *Label*: {group_name}\n" - for item_name, item_value in asdict(group).items(): - string += f" * *{item_name}*: {item_value}\n" - - string += "\n" - for name in self.model_items: - items = getattr(self, name) + for name, items in self.iterate_items(): if not items: continue string += f"{base_heading}# {name.replace('_', ' ').title()}\n\n" if isinstance(items, dict): - items = items.values() + items = items.values() # type:ignore[assignment] for item in items: - item_str = item.markdown( - all_parameters=parameters, initial_parameters=initial_parameters + assert isinstance(item, Item) + item_str = item_to_markdown( + item, parameters=parameters, initial_parameters=initial_parameters ).split("\n") - string += f"* {item_str[0]}\n" + string += f"* **{getattr(item, 'label', '')}**\n" for s in item_str[1:]: string += f" {s}\n" string += "\n" return MarkdownStr(string) def _repr_markdown_(self) -> str: - """Special method used by ``ipython`` to render markdown.""" - return str(self.markdown(base_heading_level=3)) + """Render ``ipython`` markdown. - def __str__(self) -> str: - return str(self.markdown()) + Returns + ------- + str + """ + return str(self.markdown(base_heading_level=3)) diff --git a/glotaran/model/model.pyi b/glotaran/model/model.pyi deleted file mode 100644 index b61481518..000000000 --- a/glotaran/model/model.pyi +++ /dev/null @@ -1,84 +0,0 @@ -from typing import Any - -from _typeshed import Incomplete - -from glotaran.deprecation import raise_deprecation_error -from glotaran.io import load_model -from glotaran.model.clp_penalties import EqualAreaPenalty -from glotaran.model.constraint import Constraint -from glotaran.model.dataset_group import DatasetGroup -from glotaran.model.dataset_group import DatasetGroupModel -from glotaran.model.dataset_model import DatasetModel -from glotaran.model.dataset_model import create_dataset_model_type -from glotaran.model.megacomplex import Megacomplex -from glotaran.model.megacomplex import create_model_megacomplex_type -from glotaran.model.relation import Relation -from glotaran.model.util import ModelError -from glotaran.model.weight import Weight -from glotaran.parameter import Parameter -from glotaran.parameter import ParameterGroup -from glotaran.plugin_system.megacomplex_registration import get_megacomplex -from glotaran.utils.ipython import MarkdownStr - -default_model_items: Incomplete -default_dataset_properties: Incomplete -root_parameter_error: Incomplete - -class Model: - loader: Incomplete - source_path: str - def __init__( - self, - *, - megacomplex_types: dict[str, type[Megacomplex]], - default_megacomplex_type: str | None = ..., - dataset_group_models: dict[str, DatasetGroupModel] = ..., - ) -> None: ... - @classmethod - def from_dict( - cls, - model_dict: dict[str, Any], - *, - megacomplex_types: dict[str, type[Megacomplex]] | None = ..., - default_megacomplex_type: str | None = ..., - ) -> Model: ... - @property - def model_dimension(self) -> None: ... - @property - def global_dimension(self) -> None: ... - @property - def default_megacomplex(self) -> str: ... - @property - def megacomplex_types(self) -> dict[str, type[Megacomplex]]: ... 
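`_repr_markdown_` is the IPython display hook; returning the markdown string is all that is needed for notebooks to render the model. The hook pattern by itself (`SketchReport` is illustrative):

```python
class SketchReport:
    def markdown(self, base_heading_level: int = 1) -> str:
        return f"{'#' * base_heading_level} Model\n\n* example item"

    def _repr_markdown_(self) -> str:
        # IPython calls this automatically when displaying the object
        return self.markdown(base_heading_level=3)
```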
- @property - def dataset_group_models(self) -> dict[str, DatasetGroupModel]: ... - @property - def model_items(self) -> dict[str, type[object]]: ... - @property - def global_megacomplex(self) -> dict[str, Megacomplex]: ... - def get_dataset_groups(self) -> dict[str, DatasetGroup]: ... - def as_dict(self) -> dict: ... - def get_parameter_labels(self) -> list[str]: ... - def generate_parameters(self) -> dict | list: ... - def need_index_dependent(self) -> bool: ... - def problem_list(self, parameters: ParameterGroup | None = ...) -> list[str]: ... - def validate( - self, parameters: ParameterGroup = ..., raise_exception: bool = ... - ) -> MarkdownStr: ... - def valid(self, parameters: ParameterGroup = ...) -> bool: ... - def markdown( - self, - parameters: ParameterGroup = ..., - initial_parameters: ParameterGroup = ..., - base_heading_level: int = ..., - ) -> MarkdownStr: ... - @property - def clp_area_penalties(self) -> list[EqualAreaPenalty]: ... - @property - def clp_constraints(self) -> dict[str, Constraint]: ... - @property - def clp_relations(self) -> dict[str, Relation]: ... - @property - def dataset(self) -> dict[str, DatasetModel]: ... - @property - def weights(self) -> dict[str, Weight]: ... diff --git a/glotaran/model/property.py b/glotaran/model/property.py deleted file mode 100644 index 25a32e82f..000000000 --- a/glotaran/model/property.py +++ /dev/null @@ -1,433 +0,0 @@ -"""This module holds the model property class.""" -from __future__ import annotations - -from typing import TYPE_CHECKING -from typing import Any -from typing import Callable -from typing import Mapping -from typing import Sequence -from typing import TypeVar - -from glotaran.model.util import get_subtype -from glotaran.model.util import is_mapping_type -from glotaran.model.util import is_scalar_type -from glotaran.model.util import is_sequence_type -from glotaran.model.util import wrap_func_as_method -from glotaran.parameter import Parameter -from glotaran.parameter import ParameterGroup -from glotaran.utils.ipython import MarkdownStr - -if TYPE_CHECKING: - from glotaran.model.model import Model - - -ParameterOrLabel = TypeVar("ParameterOrLabel", str, Parameter) - - -class ModelProperty(property): - """ModelProperty is an extension of the property decorator. - - It adds convenience functions for meta programming model items. - """ - - def __init__( - self, cls: type, name: str, property_type: type, doc: str, default: Any, allow_none: bool - ): - """Create a new model property. - - Parameters - ---------- - cls : type - The class the property is being attached to. - name : str - The name of the property. - property_type : type - The type of the property. - doc : str - A documentation string of for the property. - default : Any - The default value of the property. - allow_none : bool - Whether the property is allowed to be None. 
- """ - self._name = name - self._allow_none = allow_none - self._default = default - - if get_subtype(property_type) is Parameter: - if is_scalar_type(property_type): - property_type = ParameterOrLabel # type: ignore[assignment] - elif is_sequence_type(property_type): - property_type = Sequence[ParameterOrLabel] - elif is_mapping_type(property_type): - property_type = Mapping[ - property_type.__args__[0], ParameterOrLabel # type: ignore[name-defined] - ] - - self._type = property_type - - super().__init__( - fget=_model_property_getter_factory(cls, self), - fset=_model_property_setter_factory(cls, self), - doc=doc, - ) - - @property - def glotaran_allow_none(self) -> bool: - """Check if the property is allowed to be None. - - Returns - ------- - bool - Whether the property is allowed to be None. - """ - return self._allow_none - - @property - def glotaran_property_type(self) -> type: - """Get the type of the property. - - Returns - ------- - type - The type of the property. - """ - return self._type - - @property - def glotaran_is_scalar_property(self) -> bool: - """Check if the type is scalar. - - Scalar means the type is neither a sequence nor a mapping. - - Returns - ------- - bool - Whether the type is scalar. - """ - return is_scalar_type(self._type) - - @property - def glotaran_is_sequence_property(self) -> bool: - """Check if the type is a sequence. - - Returns - ------- - bool - Whether the type is a sequence. - """ - return is_sequence_type(self._type) - - @property - def glotaran_is_mapping_property(self) -> bool: - """Check if the type is mapping. - - Returns - ------- - bool - Whether the type is a mapping. - """ - return is_mapping_type(self._type) - - @property - def glotaran_property_subtype(self) -> type: - """Get the subscribed type. - - If the type is scalar, the type itself will be returned. If the type is a mapping, - the value type will be returned. - - Returns - ------- - type - The subscribed type. - """ - return get_subtype(self._type) - - @property - def glotaran_is_parameter_property(self) -> bool: - """Check if the subtype is parameter. - - Returns - ------- - bool - Whether the subtype is parameter. - """ - return self.glotaran_property_subtype is ParameterOrLabel - - def glotaran_replace_parameter_with_labels(self, value: Any) -> Any: - """Replace parameter values with their full label. - - A convenience function for serialization. - - Parameters - ---------- - value : Any - The value to replace. - - Returns - ------- - Any - The value with parameters replaced by their labels. - """ - if not self.glotaran_is_parameter_property or value is None: - return value - elif self.glotaran_is_scalar_property: - return value.full_label - elif self.glotaran_is_sequence_property: - return [v.full_label for v in value] - elif self.glotaran_is_mapping_property: - return {k: v.full_label for k, v in value.items()} - - def glotaran_validate( - self, value: Any, model: Model, parameters: ParameterGroup = None - ) -> list[str]: - """Validate a value against a model and optionally against parameters. - - Parameters - ---------- - value : Any - The value to validate. - model : Model - The model to validate against. - parameters : ParameterGroup - The parameters to validate against. - - Returns - ------- - list[str] - A list of human readable list of messages of problems. 
- """ - if value is None: - if self.glotaran_allow_none: - return [] - else: - return [f"Property '{self._name}' is none but not allowed to be none."] - - missing_model: list[tuple[str, str]] = [] - if self._name in model.model_items: - items = getattr(model, self._name) - - if self.glotaran_is_sequence_property: - missing_model.extend((self._name, item) for item in value if item not in items) - elif self.glotaran_is_mapping_property: - missing_model.extend( - (self._name, item) for item in value.values() if item not in items - ) - elif value not in items: - missing_model.append((self._name, value)) - missing_model_messages = [ - f"Missing Model Item: '{name}'['{label}']" for name, label in missing_model - ] - - missing_parameters: list[str] = [] - if parameters is not None and self.glotaran_is_parameter_property: - wanted = value - if self.glotaran_is_scalar_property: - wanted = [wanted] - elif self.glotaran_is_mapping_property: - wanted = wanted.values() - missing_parameters.extend( - parameter.full_label - for parameter in wanted - if not parameters.has(parameter.full_label) - ) - missing_parameters_messages = [f"Missing Parameter: '{p}'" for p in missing_parameters] - - return missing_model_messages + missing_parameters_messages - - def glotaran_fill(self, value: Any, model: Model, parameter: ParameterGroup) -> Any: - """Fill a property with items from a model and parameters. - - This replaces model item labels with the actual items and sets the parameter values. - - Parameters - ---------- - value : Any - The property value. - model : Model - The model to fill in. - parameter : ParameterGroup - The parameters to fill in. - - Returns - ------- - Any - The filled value. - """ - if value is None: - return None - - if self.glotaran_is_scalar_property: - if self.glotaran_is_parameter_property: - value.set_from_group(parameter) - elif hasattr(model, self._name) and not isinstance(value, bool): - value = getattr(model, self._name)[value].fill(model, parameter) - - elif self.glotaran_is_sequence_property: - if self.glotaran_is_parameter_property: - for v in value: - v.set_from_group(parameter) - elif hasattr(model, self._name): - value = [getattr(model, self._name)[v].fill(model, parameter) for v in value] - - elif self.glotaran_is_mapping_property: - if self.glotaran_is_parameter_property: - for v in value.values(): - v.set_from_group(parameter) - elif hasattr(model, self._name): - value = { - k: getattr(model, self._name)[v].fill(model, parameter) - for (k, v) in value.items() - } - - return value - - def glotaran_value_as_markdown( - self, - value: Any, - all_parameters: ParameterGroup | None = None, - initial_parameters: ParameterGroup | None = None, - ) -> MarkdownStr: - """Get a markdown representation of the property. - - Parameters - ---------- - value : Any - The property value. - all_parameters : ParameterGroup | None - A parameter group containing the whole parameter set (used for expression lookup). - initial_parameters : ParameterGroup | None - The initial parameter. - - Returns - ------- - MarkdownStr - The property as markdown string. 
- """ - md = "" - if self.glotaran_is_scalar_property: - md = self.glotaran_format_value(value, all_parameters, initial_parameters) - elif self.glotaran_is_sequence_property: - for v in value: - md += f"\n * {self.glotaran_format_value(v,all_parameters, initial_parameters)}" - elif self.glotaran_is_mapping_property: - for k, v in value.items(): - md += ( - f"\n * {k}: " - f"{self.glotaran_format_value(v,all_parameters, initial_parameters)}" - ) - return MarkdownStr(md) - - def glotaran_format_value( - self, - value: Any, - all_parameters: ParameterGroup | None = None, - initial_parameters: ParameterGroup | None = None, - ) -> str: - """Format a value to string. - - Parameters - ---------- - value : Any - The value to format. - all_parameters : ParameterGroup | None - A parameter group containing the whole parameter set (used for expression lookup). - initial_parameters : ParameterGroup | None - The initial parameter. - - Returns - ------- - str - The formatted value. - """ - return ( - value.markdown(all_parameters, initial_parameters) - if self.glotaran_is_parameter_property - else str(value) - ) - - def glotaran_get_parameter_labels(self, value: Any) -> list[str]: - """Get a list of all parameter labels if the property is parameter. - - Parameters - ---------- - value : Any - The value of the property. - - Returns - ------- - list[str] - The list of full parameter labels. - """ - if value is None or not self.glotaran_is_parameter_property: - return [] - elif self.glotaran_is_sequence_property: - return [v.full_label for v in value] - elif self.glotaran_is_mapping_property: - return [v.full_label for v in value.values()] - return [value.full_label] - - -def _model_property_getter_factory(cls: type, model_property: ModelProperty) -> Callable: - """Create a getter function for model property. - - Parameters - ---------- - cls: type - The class to create the getter for. - model_property : ModelProperty - The property to create the getter for. - - Returns - ------- - Callable - The created getter. - """ - - @wrap_func_as_method(cls, name=model_property._name) - def getter(self) -> model_property.glotaran_property_type: # type: ignore[name-defined] - value = getattr(self, f"_{model_property._name}") - if value is None: - value = model_property._default - return value - - return getter - - -def _model_property_setter_factory(cls: type, model_property: ModelProperty): - """Create a setter function for model property. - - Parameters - ---------- - cls: type - The class to create the setter for. - model_property : ModelProperty - The property to create the setter for. - - Returns - ------- - Callable - The created setter. - """ - - @wrap_func_as_method(cls, name=model_property._name) - def setter(self, value: model_property.glotaran_property_type): # type: ignore[name-defined] - if value is None and not model_property._allow_none: - raise ValueError( - f"Property '{model_property._name}' of '{cls.__name__}' " - "is not allowed to set to None." 
- ) - if value is not None and model_property.glotaran_is_parameter_property: - if model_property.glotaran_is_scalar_property and not isinstance(value, Parameter): - value = Parameter(full_label=str(value)) - elif model_property.glotaran_is_sequence_property and all( - not isinstance(v, Parameter) for v in value - ): - value = [Parameter(full_label=str(v)) for v in value] - elif model_property.glotaran_is_mapping_property and all( - not isinstance(v, Parameter) for v in value.values() - ): - value = {k: Parameter(full_label=str(v)) for k, v in value.items()} - setattr(self, f"_{model_property._name}", value) - - return setter diff --git a/glotaran/model/relation.py b/glotaran/model/relation.py deleted file mode 100644 index d843829eb..000000000 --- a/glotaran/model/relation.py +++ /dev/null @@ -1,21 +0,0 @@ -""" Glotaran Relation """ -from __future__ import annotations - -from glotaran.model.interval_property import IntervalProperty -from glotaran.model.item import model_item -from glotaran.parameter import Parameter - - -@model_item( - properties={ - "source": str, - "target": str, - "parameter": Parameter, - }, - has_label=False, -) -class Relation(IntervalProperty): - """Applies a relation between clps as - - :math:`target = parameter * source`. - """ diff --git a/glotaran/model/test/test_dataset_model.py b/glotaran/model/test/test_dataset_model.py index ac89ab74d..43b59432e 100644 --- a/glotaran/model/test/test_dataset_model.py +++ b/glotaran/model/test/test_dataset_model.py @@ -3,86 +3,109 @@ import pytest -from glotaran.builtin.megacomplexes.baseline import BaselineMegacomplex -from glotaran.builtin.megacomplexes.coherent_artifact import CoherentArtifactMegacomplex -from glotaran.builtin.megacomplexes.damped_oscillation import DampedOscillationMegacomplex -from glotaran.builtin.megacomplexes.decay import DecayMegacomplex -from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex -from glotaran.model.dataset_model import create_dataset_model_type -from glotaran.model.model import default_dataset_properties - - -class MockModel: - """Test Model only containing the megacomplex property. - - Multiple and different kinds of megacomplexes are defined - but only a subset will be used by the DatsetModel. 
- """ - - def __init__(self) -> None: - self.megacomplex = { - # not unique - "d1": DecayMegacomplex(), - "d2": DecayMegacomplex(), - "d3": DecayMegacomplex(), - "s1": SpectralMegacomplex(), - "s2": SpectralMegacomplex(), - "s3": SpectralMegacomplex(), - "doa1": DampedOscillationMegacomplex(), - "doa2": DampedOscillationMegacomplex(), - # unique - "b1": BaselineMegacomplex(), - "b2": BaselineMegacomplex(), - "c1": CoherentArtifactMegacomplex(), - "c2": CoherentArtifactMegacomplex(), - } - - -@pytest.mark.parametrize( - "used_megacomplexes, expected_problems", - ( - ( - ["d1"], - [], - ), - ( - ["d1", "d2", "d3"], - [], - ), - ( - ["s1", "s2", "s3"], - [], - ), - ( - ["d1", "d2", "d3", "s1", "s2", "s3", "doa1", "doa2", "b1", "c1"], - [], - ), - ( - ["d1", "b1", "b2"], - ["Multiple instances of unique megacomplex type 'baseline' in dataset 'ds1'"], - ), - ( - ["d1", "c1", "c2"], - ["Multiple instances of unique megacomplex type 'coherent-artifact' in dataset 'ds1'"], - ), - ( - ["d1", "b1", "b2", "c1", "c2"], - [ - "Multiple instances of unique megacomplex type 'baseline' in dataset 'ds1'", - "Multiple instances of unique megacomplex type " - "'coherent-artifact' in dataset 'ds1'", - ], - ), - ), -) -def test_datasetmodel_ensure_unique_megacomplexes( - used_megacomplexes: list[str], expected_problems: list[str] -): - """Only report problems if multiple unique megacomplexes of the same type are used.""" - dataset_model = create_dataset_model_type({**default_dataset_properties})() - dataset_model.megacomplex = used_megacomplexes # type:ignore - dataset_model.label = "ds1" # type:ignore - problems = dataset_model.ensure_unique_megacomplexes(MockModel()) # type:ignore - - assert len(problems) == len(expected_problems) - assert problems == expected_problems +from glotaran.model.dataset_model import get_dataset_model_model_dimension +from glotaran.model.item import fill_item +from glotaran.model.item import get_item_issues +from glotaran.model.megacomplex import Megacomplex +from glotaran.model.megacomplex import megacomplex +from glotaran.model.model import Model +from glotaran.parameter import Parameters + + +@megacomplex() +class MockMegacomplexNonUniqueExclusive(Megacomplex): + type: str = "test_megacomplex_not_exclusive_unique" + + +@megacomplex(exclusive=True) +class MockMegacomplexExclusive(Megacomplex): + type: str = "test_megacomplex_exclusive" + + +@megacomplex(unique=True) +class MockMegacomplexUnique(Megacomplex): + type: str = "test_megacomplex_unique" + + +@megacomplex() +class MockMegacomplexDim1(Megacomplex): + dimension: str = "dim1" + type: str = "test_megacomplex_dim1" + + +@megacomplex() +class MockMegacomplexDim2(Megacomplex): + dimension: str = "dim2" + type: str = "test_megacomplex_dim2" + + +def test_get_issues_datasetmodel(): + mcls = Model.create_class_from_megacomplexes( + [ + MockMegacomplexNonUniqueExclusive, + MockMegacomplexExclusive, + MockMegacomplexUnique, + ] + ) + m = mcls( + megacomplex={ + "m": {"type": "test_megacomplex_not_exclusive_unique"}, + "m_exclusive": {"type": "test_megacomplex_exclusive"}, + "m_unique": {"type": "test_megacomplex_unique"}, + }, + dataset={ + "ok": {"megacomplex": ["m"]}, + "exclusive": {"megacomplex": ["m", "m_exclusive"]}, + "unique": {"megacomplex": ["m_unique", "m_unique"]}, + }, + ) + + assert len(get_item_issues(item=m.dataset["ok"], model=m)) == 0 + assert len(get_item_issues(item=m.dataset["exclusive"], model=m)) == 1 + assert len(get_item_issues(item=m.dataset["unique"], model=m)) == 2 + + m = mcls( + megacomplex={ + "m": 
{"type": "test_megacomplex_not_exclusive_unique"}, + "m_exclusive": {"type": "test_megacomplex_exclusive"}, + "m_unique": {"type": "test_megacomplex_unique"}, + }, + dataset={ + "ok": {"megacomplex": [], "global_megacomplex": ["m"]}, + "exclusive": {"megacomplex": [], "global_megacomplex": ["m", "m_exclusive"]}, + "unique": {"megacomplex": [], "global_megacomplex": ["m_unique", "m_unique"]}, + }, + ) + + assert len(get_item_issues(item=m.dataset["ok"], model=m)) == 0 + assert len(get_item_issues(item=m.dataset["exclusive"], model=m)) == 1 + assert len(get_item_issues(item=m.dataset["unique"], model=m)) == 2 + + +def test_get_model_dim(): + mcls = Model.create_class_from_megacomplexes([MockMegacomplexDim1, MockMegacomplexDim2]) + m = mcls( + megacomplex={ + "m1": {"type": "test_megacomplex_dim1"}, + "m2": {"type": "test_megacomplex_dim2"}, + }, + dataset={ + "ok": {"megacomplex": ["m1"]}, + "error1": {"megacomplex": []}, + "error2": {"megacomplex": ["m1", "m2"]}, + }, + ) + + get_dataset_model_model_dimension( + fill_item(m.dataset["ok"], model=m, parameters=Parameters.from_list([])) + ) + with pytest.raises(ValueError, match="Dataset model 'ok' was not filled."): + get_dataset_model_model_dimension(m.dataset["ok"]) + with pytest.raises(ValueError, match="No megacomplex set for dataset model 'error1'."): + get_dataset_model_model_dimension(m.dataset["error1"]) + with pytest.raises( + ValueError, match="Megacomplex dimensions do not match for dataset model 'error2'." + ): + get_dataset_model_model_dimension( + fill_item(m.dataset["error2"], model=m, parameters=Parameters.from_list([])) + ) diff --git a/glotaran/model/test/test_item.py b/glotaran/model/test/test_item.py new file mode 100644 index 000000000..d66e14084 --- /dev/null +++ b/glotaran/model/test/test_item.py @@ -0,0 +1,123 @@ +from attrs import fields + +from glotaran.model.item import ModelItem +from glotaran.model.item import ModelItemType +from glotaran.model.item import ParameterType +from glotaran.model.item import fill_item +from glotaran.model.item import get_item_model_issues +from glotaran.model.item import get_item_parameter_issues +from glotaran.model.item import item +from glotaran.model.item import model_attributes +from glotaran.model.item import strip_type_and_structure_from_attribute +from glotaran.model.megacomplex import Megacomplex +from glotaran.model.megacomplex import megacomplex +from glotaran.model.model import Model +from glotaran.parameter import Parameter +from glotaran.parameter import Parameters + + +@item +class MockModelItem(ModelItem): + p_scalar: ParameterType + p_list: list[ParameterType] + p_dict: dict[str, ParameterType] + + +@megacomplex() +class MockMegacomplexItems(Megacomplex): + type: str = "test_model_items_megacomplex" + item1: ModelItemType[MockModelItem] + item2: list[ModelItemType[MockModelItem]] + item3: dict[str, ModelItemType[MockModelItem]] + + +def test_strip_type_and_structure_from_attribute(): + @item + class MockItem: + pscalar: int = None + pscalar_option: int | None = None + plist: list[int] = None + plist_option: list[int] | None = None + pdict: dict[str, int] = None + pdict_option: dict[str, int] | None = None + iscalar: ModelItemType[int] = None + iscalar_option: ModelItemType[int] | None = None + ilist: list[ModelItemType[int]] = None + ilist_option: list[ModelItemType[int]] | None = None + idict: dict[str, ModelItemType[int]] = None + idict_option: dict[str, ModelItemType[int]] | None = None + + for attr in fields(MockItem): + structure, type_ = 
strip_type_and_structure_from_attribute(attr) + print(attr.name, attr.type, structure, type_) + assert structure in (None, dict, list) + assert type_ is int + + +def test_model_get_items(): + items = list(model_attributes(MockMegacomplexItems)) + + assert len(items) == 3 + assert items[0].name == "item1" + assert items[1].name == "item2" + assert items[2].name == "item3" + + +def test_get_issues(): + mcls = Model.create_class_from_megacomplexes([MockMegacomplexItems]) + model = mcls( + megacomplex={ + "m1": { + "type": "test_model_items_megacomplex", + "item1": "item1", + "item2": ["item2"], + "item3": {"foo": "item3"}, + } + }, + item1={"test": {"p_scalar": "p1", "p_list": ["p2"], "p_dict": {"p": "p2"}}}, + ) + + m = model.megacomplex["m1"] + issues = get_item_model_issues(m, model) + assert len(issues) == 3 + + p = Parameters({}) + i = model.item1["test"] + issues = get_item_parameter_issues(i, p) + assert len(issues) == 3 + + issues = model.get_issues(parameters=p) + assert len(issues) == 6 + + +def test_fill_item(): + mcls = Model.create_class_from_megacomplexes([MockMegacomplexItems]) + model = mcls( + megacomplex={ + "m1": { + "type": "test_model_items_megacomplex", + "item1": "item", + "item2": ["item"], + "item3": {"foo": "item"}, + } + }, + item1={"item": {"p_scalar": "1", "p_list": ["2"], "p_dict": {"p": "2"}}}, + item2={"item": {"p_scalar": "1", "p_list": ["2"], "p_dict": {"p": "2"}}}, + item3={"item": {"p_scalar": "1", "p_list": ["2"], "p_dict": {"p": "2"}}}, + ) + + parameters = Parameters.from_list([2, 3, 4]) + assert model.valid(parameters) + + m = fill_item(model.megacomplex["m1"], model, parameters) + assert isinstance(m.item1, MockModelItem) + assert all(isinstance(v, MockModelItem) for v in m.item2) + assert all(isinstance(v, MockModelItem) for v in m.item3.values()) + + i = m.item1 + assert isinstance(i.p_scalar, Parameter) + assert all(isinstance(v, Parameter) for v in i.p_list) + assert all(isinstance(v, Parameter) for v in i.p_dict.values()) + assert i.p_scalar.value == 2 + assert i.p_list[0].value == 3 + assert i.p_dict["p"].value == 3 diff --git a/glotaran/model/test/test_megacomplex.py b/glotaran/model/test/test_megacomplex.py new file mode 100644 index 000000000..9cb48c661 --- /dev/null +++ b/glotaran/model/test/test_megacomplex.py @@ -0,0 +1,63 @@ +from glotaran.model.dataset_model import DatasetModel +from glotaran.model.item import ModelItem +from glotaran.model.item import ModelItemType +from glotaran.model.item import item +from glotaran.model.megacomplex import Megacomplex +from glotaran.model.megacomplex import megacomplex +from glotaran.model.model import Model + + +@item +class MockItem(ModelItem): + value: int + + +@item +class MockDatasetModel1(DatasetModel): + test_dataset_prop: int + + +@megacomplex(dataset_model_type=MockDatasetModel1) +class MockMegacomplex1(Megacomplex): + type: str = "mock-complex-1" + test_item: ModelItemType[MockItem] + + +@item +class MockDatasetModel2(DatasetModel): + test_dataset_str: str + + +@megacomplex(dataset_model_type=MockDatasetModel2) +class MockMegacomplex2(Megacomplex): + type: str = "mock-complex-2" + test_str: str = "foo" + + +def test_add_item_fields_to_model(): + mcls = Model.create_class_from_megacomplexes([MockMegacomplex1]) + m = mcls() + print(m) + assert isinstance(m.dataset, dict) + assert m.dataset == {} + assert isinstance(m.test_item, dict) + assert m.test_item == {} + + m = mcls( + dataset={"d1": {"megacomplex": ["m1"], "test_dataset_prop": 21}}, + megacomplex={"m1": {"type": "mock-complex-1", 
"test_item": "item1"}}, + test_item={"item1": {"value": 42}}, + ) + print(m) + assert "m1" in m.megacomplex + assert isinstance(m.megacomplex["m1"], MockMegacomplex1) + assert m.megacomplex["m1"].test_item == "item1" + + assert "d1" in m.dataset + assert m.dataset["d1"].label == "d1" + assert m.dataset["d1"].megacomplex == ["m1"] + assert m.dataset["d1"].test_dataset_prop == 21 + + assert "item1" in m.test_item + assert m.test_item["item1"].label == "item1" + assert m.test_item["item1"].value == 42 diff --git a/glotaran/model/test/test_model.py b/glotaran/model/test/test_model.py index b57ade5ad..924786288 100644 --- a/glotaran/model/test/test_model.py +++ b/glotaran/model/test/test_model.py @@ -1,115 +1,79 @@ -from copy import copy -from math import inf -from math import nan from textwrap import dedent -from typing import Dict -from typing import List -from typing import Tuple import pytest -from IPython.core.formatters import format_display_data - -from glotaran.io import load_parameters -from glotaran.model import DatasetModel -from glotaran.model import Megacomplex -from glotaran.model import megacomplex -from glotaran.model import model_item -from glotaran.model.clp_penalties import EqualAreaPenalty -from glotaran.model.constraint import Constraint -from glotaran.model.constraint import OnlyConstraint -from glotaran.model.constraint import ZeroConstraint -from glotaran.model.interval_property import IntervalProperty -from glotaran.model.model import Model -from glotaran.model.relation import Relation -from glotaran.model.weight import Weight -from glotaran.parameter import Parameter -from glotaran.parameter import ParameterGroup -from glotaran.testing.simulated_data.parallel_spectral_decay import MODEL - - -@model_item( - properties={ - "param": Parameter, - "megacomplex": str, - "param_list": List[Parameter], - "default_item": {"type": int, "default": 42}, - "complex": {"type": Dict[Tuple[str, str], Parameter]}, - }, -) -class MockItem: - pass - - -@model_item( - properties={ - "param": Parameter, - "param_list": List[Parameter], - "param_dict": {"type": Dict[Tuple[str, str], Parameter]}, - "number": int, - }, -) -class MockItemSimple: - pass - -@model_item(has_label=False) -class MockItemNoLabel: - pass +from glotaran.model.dataset_model import DatasetModel +from glotaran.model.item import ModelItem +from glotaran.model.item import ModelItemType +from glotaran.model.item import ParameterType +from glotaran.model.item import item +from glotaran.model.megacomplex import Megacomplex +from glotaran.model.megacomplex import megacomplex +from glotaran.model.model import DEFAULT_DATASET_GROUP +from glotaran.model.model import Model -@megacomplex(dimension="model", model_items={"test_item1": {"type": MockItem, "allow_none": True}}) -class MockMegacomplex1(Megacomplex): - pass +@item +class MockItemSimple(ModelItem): + param: ParameterType + param_list: list[ParameterType] + param_dict: dict[tuple[str, str], ParameterType] + megacomplex: ModelItemType[Megacomplex] + number: int = 42 -@megacomplex(dimension="model", model_items={"test_item2": MockItemNoLabel}) -class MockMegacomplex2(Megacomplex): - pass +@megacomplex() +class MockMegacomplexSimple(Megacomplex): + type: str = "simple" + dimension: str = "model" + test_item: ModelItemType[MockItemSimple] | None -@megacomplex(model_items={"test_item3": List[MockItem]}) -class MockMegacomplex3(Megacomplex): - pass +@megacomplex() +class MockMegacomplexItemList(Megacomplex): + type: str = "list" + dimension: str = "model" + test_item_in_list: 
list[ModelItemType[MockItemSimple]] -@megacomplex(dimension="model", model_items={"test_item4": Dict[str, MockItem]}) -class MockMegacomplex4(Megacomplex): - pass +@megacomplex() +class MockMegacomplexItemDict(Megacomplex): + type: str = "dict" + dimension: str = "model" + test_item_in_dict: dict[str, ModelItemType[MockItemSimple]] -@megacomplex( - dimension="model", - dataset_model_items={"test_item_dataset": MockItem}, - dataset_properties={ - "test_property_dataset1": int, - "test_property_dataset2": {"type": Parameter}, - }, -) -class MockMegacomplex5(Megacomplex): - pass +@item +class MockDatasetModel(DatasetModel): + test_item_dataset: ModelItemType[MockItemSimple] + test_property_dataset1: int + test_property_dataset2: ParameterType -@megacomplex(dimension="model", unique=True) -class MockMegacomplex6(Megacomplex): - pass +@megacomplex(dataset_model_type=MockDatasetModel) +class MockMegacomplexWithDataset(Megacomplex): + type: str = "dataset" + dimension: str = "model" -@megacomplex(dimension="model", exclusive=True) -class MockMegacomplex7(Megacomplex): - pass +@megacomplex(unique=True) +class MockMegacomplexUnique(Megacomplex): + type: str = "unique" + dimension: str = "model" -@megacomplex(dimension="model", model_items={"test_item_simple": MockItemSimple}) -class MockMegacomplex8(Megacomplex): - pass +@megacomplex(exclusive=True) +class MockMegacomplexExclusive(Megacomplex): + type: str = "exclusive" + dimension: str = "model" @pytest.fixture def test_model_dict(): model_dict = { "megacomplex": { - "m1": {"test_item1": "t2"}, - "m2": {"type": "type5", "dimension": "model2"}, + "m1": {"type": "simple", "test_item": "t2"}, + "m2": {"type": "dataset", "dimension": "model2"}, }, "dataset_groups": { "testgroup": {"residual_function": "non_negative_least_squares", "link_clp": True} @@ -122,19 +86,19 @@ def test_model_dict(): "value": 5.4, } ], - "test_item1": { + "test_item": { "t1": { "param": "foo", "megacomplex": "m1", "param_list": ["bar", "baz"], - "complex": {("s1", "s2"): "baz"}, + "param_dict": {("s1", "s2"): "baz"}, }, "t2": { "param": "baz", "megacomplex": "m2", "param_list": ["foo"], - "complex": {}, - "default_item": 7, + "param_dict": {}, + "number": 7, }, }, "dataset": { @@ -156,473 +120,314 @@ def test_model_dict(): }, }, } - model_dict["test_item_dataset"] = model_dict["test_item1"] + model_dict["test_item_dataset"] = model_dict["test_item"] return model_dict @pytest.fixture def test_model(test_model_dict): - return Model.from_dict( - test_model_dict, - megacomplex_types={ - "type1": MockMegacomplex1, - "type5": MockMegacomplex5, - }, + mcls = Model.create_class_from_megacomplexes( + [MockMegacomplexSimple, MockMegacomplexWithDataset] ) + return mcls(**test_model_dict) -@pytest.fixture -def model_error(): - model_dict = { - "megacomplex": { - "m1": {}, - "m2": {"type": "type2"}, - "m3": {"type": "type2"}, - "m4": {"type": "type3"}, - }, - "test_item1": { - "t1": { - "param": "fool", - "megacomplex": "mX", - "param_list": ["bar", "bay"], - "complex": {("s1", "s3"): "boz"}, - }, - }, - "dataset": { - "dataset1": { - "megacomplex": ["N1", "N2"], - "scale": "scale_1", - }, - "dataset2": { - "megacomplex": ["mrX", "m4"], - "scale": "scale_3", - }, - "dataset3": { - "megacomplex": ["m2", "m3"], +def test_model_create_class(): + m = Model.create_class([])(dataset={}) + print(m) + assert DEFAULT_DATASET_GROUP in m.dataset_groups + + m = Model.create_class([])( + **{ + "dataset": {}, + "dataset_groups": { + "test": {"residual_function": "non_negative_least_squares", "link_clp": 
False} }, - }, - } - return Model.from_dict( - model_dict, - megacomplex_types={ - "type1": MockMegacomplex1, - "type2": MockMegacomplex6, - "type3": MockMegacomplex7, - }, + } ) - - -def test_model_init(): - model = Model( - megacomplex_types={ - "type1": MockMegacomplex1, - "type2": MockMegacomplex2, - "type3": MockMegacomplex3, - "type4": MockMegacomplex4, - "type5": MockMegacomplex5, + print(m) + assert DEFAULT_DATASET_GROUP in m.dataset_groups + assert "test" in m.dataset_groups + assert m.dataset_groups["test"].residual_function == "non_negative_least_squares" + assert not m.dataset_groups["test"].link_clp + + +def test_global_items(): + + m = Model.create_class([])( + **{ + "clp_penalties": [ + { + "type": "equal_area", + "source": "s", + "source_intervals": [(1, 2)], + "target": "t", + "target_intervals": [(1, 2)], + "parameter": "p", + "weight": 1, + } + ], + "clp_constraints": [ + { + "type": "only", + "target": "t", + "interval": [(1, 2)], + }, + { + "type": "zero", + "target": "t", + "interval": (1, 2), + }, + ], + "clp_relations": [ + { + "source": "s", + "target": "t", + "interval": [(1, 2)], + "parameter": "p", + }, + ], + "dataset": {}, + "weights": [ + {"datasets": ["d1", "d2"], "value": 1}, + {"datasets": ["d3"], "value": 2, "global_interval": (5, 6)}, + ], } ) + print(m) + assert len(m.weights) == 2 + w = m.weights[0] + assert w.datasets == ["d1", "d2"] + assert w.value == 1 + assert w.model_interval is None + assert w.global_interval is None - assert model.default_megacomplex == "type1" - - assert len(model.megacomplex_types) == 5 - assert "type1" in model.megacomplex_types - assert model.megacomplex_types["type1"] == MockMegacomplex1 - assert "type2" in model.megacomplex_types - assert model.megacomplex_types["type2"] == MockMegacomplex2 - - assert hasattr(model, "test_item1") - assert isinstance(model.test_item1, dict) - assert "test_item1" in model._model_items - assert issubclass(model._model_items["test_item1"], MockItem) - - assert hasattr(model, "test_item2") - assert isinstance(model.test_item2, list) - assert "test_item2" in model._model_items - assert issubclass(model._model_items["test_item2"], MockItemNoLabel) - - assert hasattr(model, "test_item3") - assert isinstance(model.test_item3, dict) - assert "test_item3" in model._model_items - assert issubclass(model._model_items["test_item3"], MockItem) - - assert hasattr(model, "test_item4") - assert isinstance(model.test_item4, dict) - assert "test_item4" in model._model_items - assert issubclass(model._model_items["test_item4"], MockItem) - - assert hasattr(model, "test_item_dataset") - assert isinstance(model.test_item_dataset, dict) - assert "test_item_dataset" in model._model_items - assert issubclass(model._model_items["test_item_dataset"], MockItem) - assert "test_item_dataset" in model._dataset_properties - assert issubclass(model._dataset_properties["test_item_dataset"]["type"], str) - assert "test_property_dataset1" in model._dataset_properties - assert issubclass(model._dataset_properties["test_property_dataset1"], int) - assert "test_property_dataset2" in model._dataset_properties - assert issubclass(model._dataset_properties["test_property_dataset2"]["type"], Parameter) - - assert hasattr(model, "clp_area_penalties") - assert isinstance(model.clp_area_penalties, list) - assert "clp_area_penalties" in model._model_items - assert issubclass(model._model_items["clp_area_penalties"], EqualAreaPenalty) - - assert hasattr(model, "clp_constraints") - assert isinstance(model.clp_constraints, list) - assert 
"clp_constraints" in model._model_items - assert issubclass(model._model_items["clp_constraints"], Constraint) - - assert hasattr(model, "clp_relations") - assert isinstance(model.clp_relations, list) - assert "clp_relations" in model._model_items - assert issubclass(model._model_items["clp_relations"], Relation) - - assert hasattr(model, "weights") - assert isinstance(model.weights, list) - assert "weights" in model._model_items - assert issubclass(model._model_items["weights"], Weight) - - assert hasattr(model, "dataset") - assert isinstance(model.dataset, dict) - assert "dataset" in model._model_items - assert issubclass(model._model_items["dataset"], DatasetModel) - - -@pytest.fixture -def parameter(): - params = [1, 2, ["foo", 3], ["bar", 4], ["baz", 2], ["scale_1", 2], ["scale_2", 8], 4e2] - return ParameterGroup.from_list(params) + w = m.weights[1] + assert w.datasets == ["d3"] + assert w.value == 2 + assert w.model_interval is None + assert w.global_interval == (5, 6) -def test_model_misc(test_model: Model): - assert isinstance(test_model.megacomplex["m1"], MockMegacomplex1) - assert isinstance(test_model.megacomplex["m2"], MockMegacomplex5) +def test_model_items(test_model: Model): + assert isinstance(test_model.megacomplex["m1"], MockMegacomplexSimple) + assert isinstance(test_model.megacomplex["m2"], MockMegacomplexWithDataset) assert test_model.megacomplex["m1"].dimension == "model" assert test_model.megacomplex["m2"].dimension == "model2" + assert test_model.megacomplex["m1"].test_item == "t2" - -def test_dataset_group_models(test_model: Model): - groups = test_model.dataset_group_models - assert "default" in groups - assert groups["default"].residual_function == "variable_projection" - assert groups["default"].link_clp is None - assert "testgroup" in groups - assert groups["testgroup"].residual_function == "non_negative_least_squares" - assert groups["testgroup"].link_clp - - -def test_dataset_groups(test_model: Model): - groups = test_model.get_dataset_groups() - assert "default" in groups - assert groups["default"].residual_function == "variable_projection" - assert groups["default"].link_clp is None - assert "dataset1" in groups["default"].dataset_models - assert "testgroup" in groups - assert groups["testgroup"].residual_function == "non_negative_least_squares" - assert groups["testgroup"].link_clp - assert "dataset2" in groups["testgroup"].dataset_models - - -def test_model_validity(test_model: Model, model_error: Model, parameter: ParameterGroup): - print(test_model.test_item1["t1"]) - print(test_model.problem_list()) - print(test_model.problem_list(parameter)) - assert test_model.valid() - assert test_model.valid(parameter) - print(model_error.problem_list()) - print(model_error.problem_list(parameter)) - assert not model_error.valid() - assert len(model_error.problem_list()) == 6 - assert not model_error.valid(parameter) - assert len(model_error.problem_list(parameter)) == 10 - - -def test_model_validate_missing_parameters(): - """Show list of missing parameters as a problem.""" - - model_dict = { - "default_megacomplex": "decay-sequential", - "megacomplex": { - "megacomplex_sequential_decay": { - "type": "decay-sequential", - "compartments": ["species_1", "species_2", "species_3", "species_4"], - "rates": [ - "b.missing_value_1", - "b.missing_value_2", - "b.2", - "kinetic.j.missing_value_3", - ], - "dimension": "time", - } - }, - "dataset": { - "dataset_1": { - "group": "default", - "megacomplex": ["megacomplex_sequential_decay"], - } - }, + assert 
test_model.test_item["t1"].param == "foo" # type:ignore[attr-defined] + assert test_model.test_item["t1"].param_list == ["bar", "baz"] # type:ignore[attr-defined] + assert test_model.test_item["t1"].param_dict == { # type:ignore[attr-defined] + ("s1", "s2"): "baz" } - model = Model.from_dict(model_dict) - parameters = load_parameters( - dedent( - """\ - b: - - ["missing_value_1",] - - ["missing_value_2"] - - ["2", 0.75] - kinetic: - j: - - ["missing_value_3"] - """ - ), - format_name="yml_str", + assert test_model.test_item["t1"].megacomplex == "m1" # type:ignore[attr-defined] + assert test_model.test_item["t1"].number == 42 # type:ignore[attr-defined] + assert test_model.test_item["t2"].param == "baz" # type:ignore[attr-defined] + assert test_model.test_item["t2"].param_list == ["foo"] # type:ignore[attr-defined] + assert test_model.test_item["t2"].param_dict == {} # type:ignore[attr-defined] + assert test_model.test_item["t2"].megacomplex == "m2" # type:ignore[attr-defined] + assert test_model.test_item["t2"].number == 7 # type:ignore[attr-defined] + + assert test_model.dataset["dataset1"].megacomplex == ["m1"] + assert test_model.dataset["dataset1"].global_megacomplex is None + assert test_model.dataset["dataset1"].scale == "scale_1" + assert test_model.dataset["dataset1"].test_item_dataset == "t1" # type:ignore[attr-defined] + assert test_model.dataset["dataset1"].test_property_dataset1 == 1 # type:ignore[attr-defined] + assert ( + test_model.dataset["dataset1"].test_property_dataset2 == "bar" # type:ignore[attr-defined] ) - expected = dedent( - """\ - Your model has 1 problem: - - * Parameter definition is missing values for the labels: - - b.missing_value_1 - - b.missing_value_2 - - kinetic.j.missing_value_3""" + assert test_model.dataset["dataset1"].group == DEFAULT_DATASET_GROUP + + assert test_model.dataset["dataset2"].megacomplex == ["m2"] + assert test_model.dataset["dataset2"].global_megacomplex == ["m1"] + assert test_model.dataset["dataset2"].scale == "scale_2" + assert test_model.dataset["dataset2"].test_item_dataset == "t2" # type:ignore[attr-defined] + assert test_model.dataset["dataset2"].test_property_dataset1 == 1 # type:ignore[attr-defined] + assert ( + test_model.dataset["dataset2"].test_property_dataset2 == "bar" # type:ignore[attr-defined] ) - assert str(model.validate(parameters)) == expected - - -def test_items(test_model: Model): - - assert "m1" in test_model.megacomplex - assert "m2" in test_model.megacomplex - - assert "t1" in test_model.test_item1 - t = test_model.test_item1.get("t1") - assert t.param.full_label == "foo" - assert t.megacomplex == "m1" - assert [p.full_label for p in t.param_list] == ["bar", "baz"] - assert t.default_item == 42 - assert ("s1", "s2") in t.complex - assert t.complex[("s1", "s2")].full_label == "baz" - assert "t2" in test_model.test_item1 - t = test_model.test_item1.get("t2") - assert t.param.full_label == "baz" - assert t.megacomplex == "m2" - assert [p.full_label for p in t.param_list] == ["foo"] - assert t.default_item == 7 - assert t.complex == {} - - assert "dataset1" in test_model.dataset - assert test_model.dataset.get("dataset1").megacomplex == ["m1"] - assert test_model.dataset.get("dataset1").scale.full_label == "scale_1" - - assert "dataset2" in test_model.dataset - assert test_model.dataset.get("dataset2").megacomplex == ["m2"] - assert test_model.dataset.get("dataset2").global_megacomplex == ["m1"] - assert test_model.dataset.get("dataset2").scale.full_label == "scale_2" - - assert len(test_model.weights) == 1 - w = 
test_model.weights[0] - assert w.datasets == ["d1", "d2"] - assert w.global_interval == (1, 4) - assert w.model_interval == (2, 3) - assert w.value == 5.4 - - -def test_fill(test_model: Model, parameter: ParameterGroup): - dataset = test_model.dataset.get("dataset1").fill(test_model, parameter) - assert [cmplx.label for cmplx in dataset.megacomplex] == ["m1"] - assert dataset.scale == 2 - - assert not dataset.has_global_model() - - dataset = test_model.dataset.get("dataset2").fill(test_model, parameter) - assert [cmplx.label for cmplx in dataset.megacomplex] == ["m2"] - assert dataset.scale == 8 - - assert dataset.has_global_model() - assert [cmplx.label for cmplx in dataset.global_megacomplex] == ["m1"] - - t = test_model.test_item1.get("t1").fill(test_model, parameter) - assert t.param == 3 - assert t.megacomplex.label == "m1" - assert t.param_list == [4, 2] - assert t.default_item == 42 - assert t.complex == {("s1", "s2"): 2} - t = test_model.test_item1.get("t2").fill(test_model, parameter) - assert t.param == 2 - assert t.megacomplex.label == "m2" - assert t.param_list == [3] - assert t.default_item == 7 - assert t.complex == {} + assert test_model.dataset["dataset2"].group == "testgroup" def test_model_as_dict(): model_dict = { - "default_megacomplex": "type8", + "clp_penalties": [ + { + "type": "equal_area", + "source": "s", + "source_intervals": [(1, 2)], + "target": "t", + "target_intervals": [(1, 2)], + "parameter": "p", + "weight": 1, + } + ], + "clp_constraints": [ + {"type": "only", "target": "t", "interval": [(1, 2)]}, + {"type": "zero", "target": "t", "interval": (1, 2)}, + ], + "clp_relations": [ + {"source": "s", "target": "t", "interval": [(1, 2)], "parameter": "p"}, + ], "megacomplex": { - "m1": {"test_item_simple": "t2", "dimension": "model"}, + "m1": {"type": "simple", "label": "m1", "dimension": "model", "test_item": "t1"}, + }, + "dataset_groups": { + "default": { + "label": "default", + "residual_function": "non_negative_least_squares", + "link_clp": True, + } }, - "test_item_simple": { + "test_item": { "t1": { + "label": "t1", + "number": 4, "param": "foo", + "megacomplex": "m1", "param_list": ["bar", "baz"], "param_dict": {("s1", "s2"): "baz"}, - "number": 21, }, }, - "dataset_groups": { - "default": {"link_clp": None, "residual_function": "variable_projection"} - }, "dataset": { "dataset1": { + "label": "dataset1", + "group": "default", "megacomplex": ["m1"], + "global_megacomplex": ["m1"], "scale": "scale_1", - "group": "default", + "megacomplex_scale": "scale_1", + "global_megacomplex_scale": "scale_1", + "force_index_dependent": False, }, }, + "weights": [], } - model = Model.from_dict( - model_dict, - megacomplex_types={ - "type8": MockMegacomplex8, - }, - ) - as_model_dict = model.as_dict() + as_model_dict = Model.create_class_from_megacomplexes([MockMegacomplexSimple])( + **model_dict + ).as_dict() + print("want") + print(model_dict) + print("got") + print(as_model_dict) assert as_model_dict == model_dict -def test_model_markdown_base_heading_level(test_model: Model): - """base_heading_level applies to all sections.""" - assert test_model.markdown().startswith("# Model") - assert "## Test" in test_model.markdown() - assert test_model.markdown(base_heading_level=3).startswith("### Model") - assert "#### Test" in test_model.markdown(base_heading_level=3) - - -def test_model_ipython_rendering(test_model: Model): - """Autorendering in ipython""" - rendered_obj = format_display_data(test_model)[0] - - assert "text/markdown" in rendered_obj - assert 
rendered_obj["text/markdown"].startswith("### Model") +def test_model_markdown(test_model: Model): + md = test_model.markdown() + expected = dedent( + """\ + # Model - rendered_markdown_return = format_display_data(test_model.markdown())[0] + ## Dataset Groups - assert "text/markdown" in rendered_markdown_return - assert rendered_markdown_return["text/markdown"].startswith("# Model") + * **testgroup** + * *Label*: testgroup + * *Residual Function*: non_negative_least_squares + * *Link Clp*: True + * **default** + * *Label*: default + * *Residual Function*: variable_projection -def test_interval_property(): - ip1 = IntervalProperty.from_dict({"interval": [[1, 1000]]}) - assert all(ip1.applies(x) for x in (1, 500, 100)) - assert all(not ip1.applies(x) for x in (9999, inf, nan)) + ## Weights -def test_zero_constraint(): - zc1 = ZeroConstraint.from_dict({"interval": [[1, 400], [600, 1000]], "target": "s1"}) - assert all(zc1.applies(x) for x in (1, 2, 400, 600, 1000)) - assert all(not zc1.applies(x) for x in (400.01, 500, 599.99, 9999, inf, nan)) - assert zc1.target == "s1" - zc2 = ZeroConstraint.from_dict({"interval": [[600, 700]], "target": "s2"}) - assert all(zc2.applies(x) for x in range(600, 700, 50)) - assert all(not zc2.applies(x) for x in (599.9999, 700.0001)) - assert zc2.target == "s2" + * **** + * *Datasets*: ['d1', 'd2'] + * *Global Interval*: (1, 4) + * *Model Interval*: (2, 3) + * *Value*: 5.4 -def test_only_constraint(): - oc1 = OnlyConstraint.from_dict({"interval": [[1, 400], (600, 1000)], "target": "spectra1"}) - assert all(oc1.applies(x) for x in (400.01, 500, 599.99, 9999, inf)) - assert all(not oc1.applies(x) for x in (1, 400, 600, 1000)) - assert oc1.target == "spectra1" - oc2 = OnlyConstraint.from_dict({"interval": [(600, 700)], "target": "spectra2"}) - assert oc2.applies(599) - assert not oc2.applies(650) - assert oc2.applies(701) - assert oc2.target == "spectra2" + ## Test Item + * **t1** + * *Label*: t1 + * *Param*: foo + * *Param List*: ['bar', 'baz'] + * *Param Dict*: {('s1', 's2'): 'baz'} + * *Megacomplex*: m1 + * *Number*: 42 -def test_model_markdown(): - """Full markdown string is as expected.""" - expected = dedent( - """\ - # Model + * **t2** + * *Label*: t2 + * *Param*: baz + * *Param List*: ['foo'] + * *Param Dict*: {} + * *Megacomplex*: m2 + * *Number*: 7 - _Megacomplex Types_: decay-parallel - ## Dataset Groups + ## Megacomplex - * **default**: - * *Label*: default - * *residual_function*: variable_projection - * *link_clp*: None + * **m1** + * *Label*: m1 + * *Type*: simple + * *Dimension*: model + * *Test Item*: t2 - ## Irf + * **m2** + * *Label*: m2 + * *Type*: dataset + * *Dimension*: model2 - * **gaussian_irf** (gaussian): - * *Label*: gaussian_irf - * *Type*: gaussian - * *Center*: irf.center(nan) - * *Width*: irf.width(nan) - * *Normalize*: True - * *Backsweep*: False + ## Test Item Dataset - ## Megacomplex + * **t1** + * *Label*: t1 + * *Param*: foo + * *Param List*: ['bar', 'baz'] + * *Param Dict*: {('s1', 's2'): 'baz'} + * *Megacomplex*: m1 + * *Number*: 42 - * **megacomplex_parallel_decay** (decay-parallel): - * *Label*: megacomplex_parallel_decay - * *Type*: decay-parallel - * *Compartments*: - * species_1 - * species_2 - * species_3 - * *Rates*: - * rates.species_1(nan) - * rates.species_2(nan) - * rates.species_3(nan) - * *Dimension*: time + * **t2** + * *Label*: t2 + * *Param*: baz + * *Param List*: ['foo'] + * *Param Dict*: {} + * *Megacomplex*: m2 + * *Number*: 7 ## Dataset - * **dataset_1**: - * *Label*: dataset_1 + * **dataset1** + * 
*Label*: dataset1 * *Group*: default - * *Megacomplex*: - * megacomplex_parallel_decay - * *Irf*: gaussian_irf + * *Force Index Dependent*: False + * *Megacomplex*: ['m1'] + * *Scale*: scale_1 + * *Test Item Dataset*: t1 + * *Test Property Dataset1*: 1 + * *Test Property Dataset2*: bar + + * **dataset2** + * *Label*: dataset2 + * *Group*: testgroup + * *Force Index Dependent*: False + * *Megacomplex*: ['m2'] + * *Global Megacomplex*: ['m1'] + * *Scale*: scale_2 + * *Test Item Dataset*: t2 + * *Test Property Dataset1*: 1 + * *Test Property Dataset2*: bar """ ) - model = copy(MODEL) - model.dataset_group_models["default"].link_clp = None - + print(md) # Preprocessing to remove trailing whitespace after '* *Matrix*:' - result = "\n".join([line.rstrip(" ") for line in str(MODEL.markdown()).split("\n")]) - print(result) - + expected = "\n".join([line.rstrip(" ") for line in str(expected).split("\n")]) + result = "\n".join([line.rstrip(" ") for line in str(md).split("\n")]) assert result == expected def test_get_parameter_labels(test_model: Model): - wanted = [ - "foo", - "bar", - "baz", - "baz", - "baz", - "foo", - "foo", - "bar", - "baz", - "baz", - "baz", - "foo", - "scale_1", - "bar", - "scale_2", - "bar", - ] + wanted = {"foo", "scale_1", "scale_2", "baz", "bar"} got = test_model.get_parameter_labels() - + print(got) assert wanted == got diff --git a/glotaran/model/test/test_model_property.py b/glotaran/model/test/test_model_property.py deleted file mode 100644 index b6fe18b30..000000000 --- a/glotaran/model/test/test_model_property.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Tests for glotaran.model.property.ModelProperty""" -from __future__ import annotations - -from typing import Dict -from typing import List - -from glotaran.model.property import ModelProperty -from glotaran.model.property import ParameterOrLabel -from glotaran.parameter import Parameter - - -def test_model_property_non_parameter(): - class MockClass: - pass - - p_scalar = ModelProperty(MockClass, "scalar", int, "", None, True) - assert p_scalar.glotaran_is_scalar_property - assert not p_scalar.glotaran_is_sequence_property - assert not p_scalar.glotaran_is_mapping_property - assert p_scalar.glotaran_property_subtype is int - assert not p_scalar.glotaran_is_parameter_property - assert p_scalar.glotaran_value_as_markdown(42) == "42" - - p_sequence = ModelProperty(MockClass, "sequence", List[int], "", None, True) - assert not p_sequence.glotaran_is_scalar_property - assert p_sequence.glotaran_is_sequence_property - assert not p_sequence.glotaran_is_mapping_property - assert p_sequence.glotaran_property_subtype is int - assert not p_sequence.glotaran_is_parameter_property - print(p_sequence.glotaran_value_as_markdown([1, 2])) - assert p_sequence.glotaran_value_as_markdown([1, 2]) == "\n * 1\n * 2" - - p_mapping = ModelProperty(MockClass, "mapping", Dict[str, int], "", None, True) - assert not p_mapping.glotaran_is_scalar_property - assert not p_mapping.glotaran_is_sequence_property - assert p_mapping.glotaran_is_mapping_property - assert p_mapping.glotaran_property_subtype is int - assert not p_mapping.glotaran_is_parameter_property - print(p_mapping.glotaran_value_as_markdown({"a": 1, "b": 2})) - assert p_mapping.glotaran_value_as_markdown({"a": 1, "b": 2}) == "\n * a: 1\n * b: 2" - - -def test_model_property_parameter(): - class MockClass: - pass - - p_scalar = ModelProperty(MockClass, "scalar", Parameter, "", None, True) - assert p_scalar.glotaran_is_scalar_property - assert not p_scalar.glotaran_is_sequence_property - 
assert not p_scalar.glotaran_is_mapping_property - assert p_scalar.glotaran_property_subtype is ParameterOrLabel - assert p_scalar.glotaran_is_parameter_property - - p_sequence = ModelProperty(MockClass, "sequence", List[Parameter], "", None, True) - assert not p_sequence.glotaran_is_scalar_property - assert p_sequence.glotaran_is_sequence_property - assert not p_sequence.glotaran_is_mapping_property - assert p_sequence.glotaran_property_subtype is ParameterOrLabel - assert p_sequence.glotaran_is_parameter_property - - p_mapping = ModelProperty(MockClass, "mapping", Dict[str, Parameter], "", None, True) - assert not p_mapping.glotaran_is_scalar_property - assert not p_mapping.glotaran_is_sequence_property - assert p_mapping.glotaran_is_mapping_property - assert p_mapping.glotaran_property_subtype is ParameterOrLabel - assert p_mapping.glotaran_is_parameter_property - - -def test_model_property_default_getter(): - class MockClass: - _p_default = None - - p_default = ModelProperty(MockClass, "p_default", int, "", 42, True) - assert p_default.fget(MockClass) == 42 - MockClass._p_default = 21 - assert p_default.fget(MockClass) == 21 - - -def test_model_property_parameter_setter(): - class MockClass: - pass - - p_scalar = ModelProperty(MockClass, "scalar", Parameter, "", None, True) - p_scalar.fset(MockClass, "param.foo") - value = p_scalar.fget(MockClass) - assert isinstance(value, Parameter) - assert value.full_label == "param.foo" - - p_sequence = ModelProperty(MockClass, "sequence", List[Parameter], "", None, True) - names = ["param1", "param2"] - p_sequence.fset(MockClass, names) - value = p_sequence.fget(MockClass) - assert isinstance(value, list) - assert all(isinstance(v, Parameter) for v in value) - assert [p.full_label for p in value] == names - - p_mapping = ModelProperty(MockClass, "mapping", Dict[str, Parameter], "", None, True) - p_mapping.fset(MockClass, {f"{i}": n for i, n in enumerate(names)}) - value = p_mapping.fget(MockClass) - assert isinstance(value, dict) - assert all(isinstance(v, Parameter) for v in value.values()) - assert [p.full_label for p in value.values()] == names - - -def test_model_property_parameter_to_label(): - class MockClass: - pass - - p_scalar = ModelProperty(MockClass, "scalar", Parameter, "", None, True) - p_scalar.fset(MockClass, "param.foo") - value = p_scalar.fget(MockClass) - assert p_scalar.glotaran_replace_parameter_with_labels(value) == "param.foo" - - p_sequence = ModelProperty(MockClass, "sequence", List[Parameter], "", None, True) - names = ["param1", "param2"] - p_sequence.fset(MockClass, names) - value = p_sequence.fget(MockClass) - assert p_sequence.glotaran_replace_parameter_with_labels(value) == names - - p_mapping = ModelProperty(MockClass, "mapping", Dict[str, Parameter], "", None, True) - p_mapping.fset(MockClass, {f"{i}": n for i, n in enumerate(names)}) - value = p_mapping.fget(MockClass) - assert list(p_mapping.glotaran_replace_parameter_with_labels(value).values()) == names diff --git a/glotaran/model/util.py b/glotaran/model/util.py deleted file mode 100644 index 5aeb1d163..000000000 --- a/glotaran/model/util.py +++ /dev/null @@ -1,134 +0,0 @@ -"""Helper functions.""" -from __future__ import annotations - -from typing import TYPE_CHECKING -from typing import Mapping -from typing import Sequence -from typing import Union - -if TYPE_CHECKING: - from typing import Any - from typing import Callable - from typing import TypeVar - - DecoratedFunc = TypeVar("DecoratedFunc", bound=Callable[..., Any]) # decorated function - - -class 
ModelError(Exception):
-    """Raised when a model contains errors."""
-
-    def __init__(self, error: str):
-        super().__init__(f"ModelError: {error}")
-
-
-def wrap_func_as_method(
-    cls: Any, name: str = None, annotations: dict[str, type] = None, doc: str = None
-) -> Callable[[DecoratedFunc], DecoratedFunc]:
-    """A decorator to wrap a function as class method.
-
-    Notes
-    -----
-
-    Only for internal use.
-
-    Parameters
-    ----------
-    cls :
-        The class in which the function will be wrapped.
-    name :
-        The name of method. If `None`, the original function's name is used.
-    annotations :
-        The annotations of the method. If `None`, the original function's annotations are used.
-    doc :
-        The documentation of the method. If `None`, the original function's documentation is used.
-    """
-
-    def wrapper(func: DecoratedFunc) -> DecoratedFunc:
-        if name:
-            func.__name__ = name
-        if annotations:
-            setattr(func, "__annotations__", annotations)
-        if doc:
-            func.__doc__ = doc
-        func.__qualname__ = cls.__qualname__ + "." + func.__name__
-        func.__module__ = cls.__module__
-
-        return func
-
-    return wrapper
-
-
-def is_scalar_type(t: type) -> bool:
-    """Check if the type is scalar.
-
-    Scalar means the type is neither a sequence nor a mapping.
-
-    Parameters
-    ----------
-    t : type
-        The type to check.
-
-    Returns
-    -------
-    bool
-        Whether the type is scalar.
-    """
-    if hasattr(t, "__origin__"):
-        # Union can for some reason not be used in issubclass
-        return t.__origin__ is Union or not issubclass(t.__origin__, (Sequence, Mapping))
-    return True
-
-
-def is_sequence_type(t: type) -> bool:
-    """Check if the type is a sequence.
-
-    Parameters
-    ----------
-    t : type
-        The type to check.
-
-    Returns
-    -------
-    bool
-        Whether the type is a sequence.
-    """
-    return not is_scalar_type(t) and issubclass(t.__origin__, Sequence)
-
-
-def is_mapping_type(t: type) -> bool:
-    """Check if the type is mapping.
-
-    Parameters
-    ----------
-    t : type
-        The type to check.
-
-    Returns
-    -------
-    bool
-        Whether the type is a mapping.
-    """
-    return not is_scalar_type(t) and issubclass(t.__origin__, Mapping)
-
-
-def get_subtype(t: type) -> type:
-    """Gets the subscribed type of a generic type.
-
-    If the type is scalar, the type itself will be returned. If the type is a mapping,
-    the value type will be returned.
-
-    Parameters
-    ----------
-    t : type
-        The origin type.
-
-    Returns
-    -------
-    type
-        The subscribed type.
-    """
-    if is_sequence_type(t):
-        return t.__args__[0]
-    elif is_mapping_type(t):
-        return t.__args__[1]
-    return t
diff --git a/glotaran/model/weight.py b/glotaran/model/weight.py
index 18408c467..6f4d0b3a7 100644
--- a/glotaran/model/weight.py
+++ b/glotaran/model/weight.py
@@ -1,31 +1,18 @@
-"""The Weight property class."""
+"""This module contains the weight item."""

-from typing import List
-from typing import Tuple
+from glotaran.model.item import Item
+from glotaran.model.item import item

-from glotaran.model.item import model_item
-
-
-@model_item(
-    properties={
-        "datasets": {type: List[str]},
-        "global_interval": {
-            "type": List[Tuple[float, float]],
-            "default": None,
-            "allow_none": True,
-        },
-        "model_interval": {
-            "type": List[Tuple[float, float]],
-            "default": None,
-            "allow_none": True,
-        },
-        "value": {"type": float},
-    },
-    has_label=False,
-)
-class Weight:
+@item
+class Weight(Item):
    """The `Weight` class describes a value by which a dataset will be scaled.

    `global_interval` and `model_interval` are optional. The whole range of the dataset will
    be used if not set.
""" + + datasets: list[str] + global_interval: tuple[float, float] | None = None + model_interval: tuple[float, float] | None = None + value: float diff --git a/glotaran/optimization/data_provider.py b/glotaran/optimization/data_provider.py index e8c2e0556..76a38ddb9 100644 --- a/glotaran/optimization/data_provider.py +++ b/glotaran/optimization/data_provider.py @@ -9,6 +9,8 @@ from glotaran.model import DatasetGroup from glotaran.model import Model +from glotaran.model.dataset_model import get_dataset_model_model_dimension +from glotaran.model.dataset_model import has_dataset_model_global_model from glotaran.project import Scheme @@ -48,7 +50,7 @@ def __init__(self, scheme: Scheme, dataset_group: DatasetGroup): for label, dataset_model in dataset_group.dataset_models.items(): dataset = scheme.data[label] - model_dimension = dataset_model.get_model_dimension() + model_dimension = get_dataset_model_model_dimension(dataset_model) self._model_axes[label] = dataset.coords[model_dimension].data self._model_dimensions[label] = model_dimension global_dimension = self.infer_global_dimension(model_dimension, dataset.data.dims) @@ -66,7 +68,7 @@ def __init__(self, scheme: Scheme, dataset_group: DatasetGroup): if self._weight[label] is not None: self._data[label] *= self._weight[label] - if dataset_model.has_global_model(): + if has_dataset_model_global_model(dataset_model): self._flattened_data[label] = self._data[label].T.flatten() self._flattened_weight[label] = ( self._weight[label].T.flatten() # type:ignore[union-attr] @@ -194,15 +196,15 @@ def add_model_weight( for model_weight in model_weights: idx = {} - if model_weight.global_interval is not None: # type:ignore[attr-defined] + if model_weight.global_interval is not None: idx[global_dimension] = self.get_axis_slice_from_interval( - model_weight.global_interval, global_axis # type:ignore[attr-defined] + model_weight.global_interval, global_axis ) - if model_weight.model_interval is not None: # type:ignore[attr-defined] + if model_weight.model_interval is not None: idx[model_dimension] = self.get_axis_slice_from_interval( - model_weight.model_interval, model_axis # type:ignore[attr-defined] + model_weight.model_interval, model_axis ) - weight[idx] *= model_weight.value # type:ignore[attr-defined] + weight[idx] *= model_weight.value self._weight[dataset_label] = weight.data diff --git a/glotaran/optimization/estimation_provider.py b/glotaran/optimization/estimation_provider.py index 2695b162f..9c5db2dcc 100644 --- a/glotaran/optimization/estimation_provider.py +++ b/glotaran/optimization/estimation_provider.py @@ -8,6 +8,10 @@ from glotaran.model import DatasetGroup from glotaran.model import DatasetModel +from glotaran.model import EqualAreaPenalty +from glotaran.model.dataset_model import has_dataset_model_global_model +from glotaran.model.dataset_model import is_dataset_model_index_dependent +from glotaran.model.item import fill_item from glotaran.optimization.data_provider import DataProvider from glotaran.optimization.data_provider import DataProviderLinked from glotaran.optimization.matrix_provider import MatrixProviderLinked @@ -128,17 +132,15 @@ def retrieve_clps( clps[idx] = reduced_clps[i] for relation in model.clp_relations: - relation = relation.fill(model, parameters) # type:ignore[attr-defined] + relation = fill_item(relation, model, parameters) # type:ignore[arg-type] if ( - relation.target in clp_labels # type:ignore[attr-defined] - and relation.applies(index) # type:ignore[attr-defined] - and relation.source in clp_labels # 
type:ignore[attr-defined] + relation.target in clp_labels + and relation.applies(index) + and relation.source in clp_labels ): - source_idx = clp_labels.index(relation.source) # type:ignore[attr-defined] - target_idx = clp_labels.index(relation.target) # type:ignore[attr-defined] - clps[target_idx] = ( - relation.parameter * clps[source_idx] # type:ignore[attr-defined] - ) + source_idx = clp_labels.index(relation.source) + target_idx = clp_labels.index(relation.target) + clps[target_idx] = relation.parameter * clps[source_idx] return clps def get_additional_penalties(self) -> list[float]: @@ -176,8 +178,10 @@ def calculate_clp_penalties( model = self.group.model parameters = self.group.parameters penalties = [] - for penalty in model.clp_area_penalties: - penalty = penalty.fill(model, parameters) + for penalty in model.clp_penalties: + if not isinstance(penalty, EqualAreaPenalty): + continue + penalty = fill_item(penalty, model, parameters) # type:ignore[arg-type] source_area = _get_area( penalty.source, @@ -284,7 +288,7 @@ def estimate(self): self._clp_penalty.clear() for dataset_model in self.group.dataset_models.values(): - if dataset_model.has_global_model(): + if has_dataset_model_global_model(dataset_model): self.calculate_full_model_estimation(dataset_model) else: self.calculate_estimation(dataset_model) @@ -300,7 +304,7 @@ def get_full_penalty(self) -> np.typing.ArrayLike: full_penalty = np.concatenate( [ self._residuals[label] - if dataset_model.has_global_model() + if has_dataset_model_global_model(dataset_model) else np.concatenate(self._residuals[label]) for label, dataset_model in self.group.dataset_models.items() ] @@ -326,7 +330,7 @@ def get_result( global_dimension = self._data_provider.get_global_dimension(label) global_axis = self._data_provider.get_global_axis(label) - if dataset_model.has_global_model(): + if has_dataset_model_global_model(dataset_model): residuals[label] = xr.DataArray( np.array(self._residuals[label]).T.reshape(model_axis.size, global_axis.size), coords={global_dimension: global_axis, model_dimension: model_axis}, @@ -348,7 +352,7 @@ def get_result( coords={global_dimension: global_axis, model_dimension: model_axis}, dims=[model_dimension, global_dimension], ) - if dataset_model.is_index_dependent(): + if is_dataset_model_index_dependent(dataset_model): clps[label] = xr.concat( [ xr.DataArray( diff --git a/glotaran/optimization/matrix_provider.py b/glotaran/optimization/matrix_provider.py index 8b6250274..b07fb4d69 100644 --- a/glotaran/optimization/matrix_provider.py +++ b/glotaran/optimization/matrix_provider.py @@ -10,7 +10,12 @@ from glotaran.model import DatasetGroup from glotaran.model import DatasetModel -from glotaran.model.interval_property import IntervalProperty +from glotaran.model.dataset_model import has_dataset_model_global_model +from glotaran.model.dataset_model import is_dataset_model_index_dependent +from glotaran.model.dataset_model import iterate_dataset_model_global_megacomplexes +from glotaran.model.dataset_model import iterate_dataset_model_megacomplexes +from glotaran.model.interval_item import IntervalItem +from glotaran.model.item import fill_item from glotaran.optimization.data_provider import DataProvider from glotaran.optimization.data_provider import DataProviderLinked @@ -118,7 +123,7 @@ def get_matrix_container(self, dataset_label: str, global_index: int) -> MatrixC The matrix container. 
""" matrix_container = self._matrix_containers[dataset_label] - if self.group.dataset_models[dataset_label].is_index_dependent(): + if is_dataset_model_index_dependent(self.group.dataset_models[dataset_label]): matrix_container = matrix_container[global_index] # type:ignore[index] return matrix_container # type:ignore[return-value] @@ -128,7 +133,7 @@ def calculate_dataset_matrices(self): model_axis = self._data_provider.get_model_axis(label) global_axis = self._data_provider.get_global_axis(label) - if dataset_model.is_index_dependent(): + if is_dataset_model_index_dependent(dataset_model): self._matrix_containers[label] = [ self.calculate_dataset_matrix( dataset_model, global_index, global_axis, model_axis @@ -171,13 +176,13 @@ def calculate_dataset_matrix( clp_labels: list[str] = [] matrix = None - megacomplex_iterator = dataset_model.iterate_megacomplexes + megacomplex_iterator = iterate_dataset_model_megacomplexes(dataset_model) if global_matrix: - megacomplex_iterator = dataset_model.iterate_global_megacomplexes + megacomplex_iterator = iterate_dataset_model_global_megacomplexes(dataset_model) model_axis, global_axis = global_axis, model_axis - for scale, megacomplex in megacomplex_iterator(): + for scale, megacomplex in megacomplex_iterator: this_clp_labels, this_matrix = megacomplex.calculate_matrix( # type:ignore[union-attr] dataset_model, global_index, global_axis, model_axis ) @@ -231,12 +236,12 @@ def combine_megacomplex_matrices( return tmp_clp_labels, tmp_matrix @staticmethod - def does_interval_property_apply(prop: IntervalProperty, index: int | None) -> bool: + def does_interval_property_apply(prop: IntervalItem, index: int | None) -> bool: """Check if an interval property applies on an index. Parameters ---------- - prop : IntervalProperty + prop : IntervalItem The interval property. index: int | None The index to check. 
@@ -344,23 +349,18 @@ def apply_relations( relation_matrix = np.diagflat([1.0 for _ in clp_labels]) idx_to_delete = [] - for relation in model.clp_relations: # type:ignore[attr-defined] - if ( - relation.target in clp_labels # type:ignore[attr-defined] - and self.does_interval_property_apply( - relation, index # type:ignore[arg-type] - ) + for relation in model.clp_relations: + if relation.target in clp_labels and self.does_interval_property_apply( + relation, index ): - if relation.source not in clp_labels: # type:ignore[attr-defined] + if relation.source not in clp_labels: continue - relation = relation.fill(model, parameters) # type:ignore[attr-defined] - source_idx = clp_labels.index(relation.source) # type:ignore[attr-defined] - target_idx = clp_labels.index(relation.target) # type:ignore[attr-defined] - relation_matrix[ - target_idx, source_idx - ] = relation.parameter # type:ignore[attr-defined] + relation = fill_item(relation, model, parameters) # type:ignore[arg-type] + source_idx = clp_labels.index(relation.source) + target_idx = clp_labels.index(relation.target) + relation_matrix[target_idx, source_idx] = relation.parameter idx_to_delete.append(target_idx) reduced_clp_labels = [ @@ -386,7 +386,7 @@ def get_result(self) -> tuple[dict[str, xr.DataArray], dict[str, xr.DataArray]]: for label, matrix_container in self._matrix_containers.items(): model_dimension = self._data_provider.get_model_dimension(label) model_axis = self._data_provider.get_model_axis(label) - if self.group.dataset_models[label].is_index_dependent(): + if is_dataset_model_index_dependent(self.group.dataset_models[label]): global_dimension = self._data_provider.get_global_dimension(label) global_axis = self._data_provider.get_global_axis(label) matrices[label] = xr.concat( @@ -509,7 +509,7 @@ def calculate(self): def calculate_global_matrices(self): """Calculate the global matrices of the datasets in the dataset group.""" for label, dataset_model in self.group.dataset_models.items(): - if dataset_model.has_global_model(): + if has_dataset_model_global_model(dataset_model): model_axis = self._data_provider.get_model_axis(label) global_axis = self._data_provider.get_global_axis(label) self._global_matrix_containers[label] = self.calculate_dataset_matrix( @@ -519,11 +519,11 @@ def calculate_global_matrices(self): def calculate_prepared_matrices(self): """Calculate the prepared matrices of the datasets in the dataset group.""" for label, dataset_model in self.group.dataset_models.items(): - if dataset_model.has_global_model(): + if has_dataset_model_global_model(dataset_model): continue scale = dataset_model.scale or 1 weight = self._data_provider.get_weight(label) - if dataset_model.is_index_dependent(): + if is_dataset_model_index_dependent(dataset_model): self._prepared_matrix_container[label] = [ self.reduce_matrix( self.get_matrix_container(label, i).create_scaled_matrix(scale), @@ -546,10 +546,10 @@ def calculate_prepared_matrices(self): def calculate_full_matrices(self): """Calculate the full matrices of the datasets in the dataset group.""" for label, dataset_model in self.group.dataset_models.items(): - if dataset_model.has_global_model(): + if has_dataset_model_global_model(dataset_model): global_matrix_container = self.get_global_matrix_container(label) - if dataset_model.is_index_dependent(): + if is_dataset_model_index_dependent(dataset_model): global_axis = self._data_provider.get_global_axis(label) full_matrix = np.concatenate( [ diff --git a/glotaran/optimization/optimization_group.py 
b/glotaran/optimization/optimization_group.py index 763092ae5..073823025 100644 --- a/glotaran/optimization/optimization_group.py +++ b/glotaran/optimization/optimization_group.py @@ -6,6 +6,7 @@ from glotaran.io.prepare_dataset import add_svd_to_dataset from glotaran.model import DatasetGroup +from glotaran.model.dataset_model import finalize_dataset_model from glotaran.optimization.data_provider import DataProvider from glotaran.optimization.data_provider import DataProviderLinked from glotaran.optimization.estimation_provider import EstimationProvider @@ -14,7 +15,7 @@ from glotaran.optimization.matrix_provider import MatrixProvider from glotaran.optimization.matrix_provider import MatrixProviderLinked from glotaran.optimization.matrix_provider import MatrixProviderUnlinked -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.project import Scheme @@ -71,12 +72,12 @@ def __init__( dataset.data.dims[1], ) - def calculate(self, parameters: ParameterGroup): + def calculate(self, parameters: Parameters): """Calculate the optimization group data. Parameters ---------- - parameters : ParameterGroup + parameters : Parameters The parameters. """ self._dataset_group.set_parameters(parameters) @@ -176,13 +177,15 @@ def create_result_data(self) -> dict[str, xr.Dataset]: ) result_dataset.attrs["dataset_scale"] = ( - 1 if dataset_model.scale is None else dataset_model.scale.value + 1 + if dataset_model.scale is None + else dataset_model.scale.value # type:ignore[union-attr] ) # reconstruct fitted data result_dataset["fitted_data"] = result_dataset.data - result_dataset.residual - dataset_model.finalize_data(result_dataset) + finalize_dataset_model(dataset_model, result_dataset) return result_datasets diff --git a/glotaran/optimization/test/models.py b/glotaran/optimization/test/models.py index 400c41beb..7c08bd084 100644 --- a/glotaran/optimization/test/models.py +++ b/glotaran/optimization/test/models.py @@ -1,18 +1,21 @@ from __future__ import annotations -from typing import List - import numpy as np from glotaran.model import DatasetModel from glotaran.model import Megacomplex from glotaran.model import Model +from glotaran.model import ParameterType +from glotaran.model import item from glotaran.model import megacomplex -from glotaran.parameter import Parameter -@megacomplex(dimension="model", properties={"is_index_dependent": bool}) +@megacomplex() class SimpleTestMegacomplex(Megacomplex): + type: str = "simple-test-mc" + dimension: str = "model" + is_index_dependent: bool + def calculate_matrix( self, dataset_model: DatasetModel, @@ -46,33 +49,20 @@ def finalize_data( pass -class SimpleTestModel(Model): - @classmethod - def from_dict( - cls, - model_dict, - *, - megacomplex_types: dict[str, type[Megacomplex]] | None = None, - default_megacomplex_type: str | None = None, - ): - defaults: dict[str, type[Megacomplex]] = {"model_complex": SimpleTestMegacomplex} - if megacomplex_types is not None: - defaults.update(megacomplex_types) - return super().from_dict( - model_dict, - megacomplex_types=defaults, - default_megacomplex_type=default_megacomplex_type, - ) - - -@megacomplex( - dimension="model", - properties={"is_index_dependent": bool}, - dataset_properties={ - "kinetic": List[Parameter], - }, -) +SimpleTestModel = Model.create_class_from_megacomplexes([SimpleTestMegacomplex]) + + +@item +class SimpleDatasetModel(DatasetModel): + kinetic: list[ParameterType] + + +@megacomplex(dataset_model_type=SimpleDatasetModel) class 
SimpleKineticMegacomplex(Megacomplex): + type: str = "simple-kinetic-test-mc" + dimension: str = "model" + is_index_dependent: bool + def calculate_matrix( self, dataset_model, @@ -103,8 +93,11 @@ def finalize_data( pass -@megacomplex(dimension="global", properties={}) +@megacomplex() class SimpleSpectralMegacomplex(Megacomplex): + type: str = "simple-spectral-test-mc" + dimension: str = "global" + def calculate_matrix( self, dataset_model, @@ -126,15 +119,14 @@ def index_dependent(self, dataset_model): return False -@megacomplex( - dimension="global", - properties={ - "location": {"type": List[Parameter], "allow_none": True}, - "amplitude": {"type": List[Parameter], "allow_none": True}, - "delta": {"type": List[Parameter], "allow_none": True}, - }, -) +@megacomplex() class ShapedSpectralMegacomplex(Megacomplex): + type: str = "shaped-spectral-test-mc" + dimension: str = "global" + location: list[ParameterType] + amplitude: list[ParameterType] + delta: list[ParameterType] + def calculate_matrix( self, dataset_model, @@ -169,24 +161,6 @@ def finalize_data( pass -class DecayModel(Model): - @classmethod - def from_dict( - cls, - model_dict, - *, - megacomplex_types: dict[str, type[Megacomplex]] | None = None, - default_megacomplex_type: str | None = None, - ): - defaults: dict[str, type[Megacomplex]] = { - "model_complex": SimpleKineticMegacomplex, - "global_complex": SimpleSpectralMegacomplex, - "global_complex_shaped": ShapedSpectralMegacomplex, - } - if megacomplex_types is not None: - defaults.update(megacomplex_types) - return super().from_dict( - model_dict, - megacomplex_types=defaults, - default_megacomplex_type=default_megacomplex_type, - ) +DecayModel = Model.create_class_from_megacomplexes( + [SimpleKineticMegacomplex, SimpleSpectralMegacomplex, ShapedSpectralMegacomplex] +) diff --git a/glotaran/optimization/test/suites.py b/glotaran/optimization/test/suites.py index c14719812..6b2ba522f 100644 --- a/glotaran/optimization/test/suites.py +++ b/glotaran/optimization/test/suites.py @@ -3,34 +3,35 @@ import numpy as np from glotaran.optimization.test.models import DecayModel -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters class OneCompartmentDecay: scale = 2 - wanted_parameters = ParameterGroup.from_list([101e-4]) - initial_parameters = ParameterGroup.from_list([100e-5, [scale, {"vary": False}]]) + wanted_parameters = Parameters.from_list([101e-4]) + initial_parameters = Parameters.from_list([100e-5, [scale, {"vary": False}]]) global_axis = np.asarray([1.0]) model_axis = np.arange(0, 150, 1.5) sim_model_dict = { - "megacomplex": {"m1": {"is_index_dependent": False}, "m2": {"type": "global_complex"}}, + "megacomplex": { + "m1": {"type": "simple-kinetic-test-mc", "is_index_dependent": False}, + "m2": {"type": "simple-spectral-test-mc"}, + }, "dataset": { "dataset1": { - "initial_concentration": [], "megacomplex": ["m1"], "global_megacomplex": ["m2"], "kinetic": ["1"], } }, } - sim_model = DecayModel.from_dict(sim_model_dict) + sim_model = DecayModel(**sim_model_dict) model_dict = { - "megacomplex": {"m1": {"is_index_dependent": False}}, + "megacomplex": {"m1": {"type": "simple-kinetic-test-mc", "is_index_dependent": False}}, "dataset": { "dataset1": { - "initial_concentration": [], "megacomplex": ["m1"], "kinetic": ["1"], "scale": "2", @@ -38,22 +39,24 @@ class OneCompartmentDecay: }, } model_dict["dataset"]["dataset1"]["scale"] = "2" # type:ignore[index] - model = DecayModel.from_dict(model_dict) + model = DecayModel(**model_dict) class 
TwoCompartmentDecay: - wanted_parameters = ParameterGroup.from_list([11e-4, 22e-5]) - initial_parameters = ParameterGroup.from_list([10e-4, 20e-5]) + wanted_parameters = Parameters.from_list([11e-4, 22e-5]) + initial_parameters = Parameters.from_list([10e-4, 20e-5]) global_axis = np.asarray([1.0]) model_axis = np.arange(0, 150, 1.5) - sim_model = DecayModel.from_dict( - { - "megacomplex": {"m1": {"is_index_dependent": False}, "m2": {"type": "global_complex"}}, + sim_model = DecayModel( + **{ + "megacomplex": { + "m1": {"type": "simple-kinetic-test-mc", "is_index_dependent": False}, + "m2": {"type": "simple-spectral-test-mc"}, + }, "dataset": { "dataset1": { - "initial_concentration": [], "megacomplex": ["m1"], "global_megacomplex": ["m2"], "kinetic": ["1", "2"], @@ -61,12 +64,11 @@ class TwoCompartmentDecay: }, } ) - model = DecayModel.from_dict( - { - "megacomplex": {"m1": {"is_index_dependent": False}}, + model = DecayModel( + **{ + "megacomplex": {"m1": {"type": "simple-kinetic-test-mc", "is_index_dependent": False}}, "dataset": { "dataset1": { - "initial_concentration": [], "megacomplex": ["m1"], "kinetic": ["1", "2"], } @@ -76,8 +78,8 @@ class TwoCompartmentDecay: class ThreeDatasetDecay: - wanted_parameters = ParameterGroup.from_list([101e-4, 201e-3]) - initial_parameters = ParameterGroup.from_list([100e-5, 200e-3]) + wanted_parameters = Parameters.from_list([101e-4, 201e-3]) + initial_parameters = Parameters.from_list([100e-5, 200e-3]) global_axis = np.asarray([1.0]) model_axis = np.arange(0, 150, 1.5) @@ -89,47 +91,46 @@ class ThreeDatasetDecay: model_axis3 = np.arange(0, 150, 1.5) sim_model_dict = { - "megacomplex": {"m1": {"is_index_dependent": False}, "m2": {"type": "global_complex"}}, + "megacomplex": { + "m1": {"type": "simple-kinetic-test-mc", "is_index_dependent": False}, + "m2": {"type": "simple-spectral-test-mc"}, + }, "dataset": { "dataset1": { - "initial_concentration": [], "megacomplex": ["m1"], "global_megacomplex": ["m2"], "kinetic": ["1"], }, "dataset2": { - "initial_concentration": [], "megacomplex": ["m1"], "global_megacomplex": ["m2"], "kinetic": ["1", "2"], }, "dataset3": { - "initial_concentration": [], "megacomplex": ["m1"], "global_megacomplex": ["m2"], "kinetic": ["2"], }, }, } - sim_model = DecayModel.from_dict(sim_model_dict) + sim_model = DecayModel(**sim_model_dict) model_dict = { - "megacomplex": {"m1": {"is_index_dependent": False}}, + "megacomplex": {"m1": {"type": "simple-kinetic-test-mc", "is_index_dependent": False}}, "dataset": { - "dataset1": {"initial_concentration": [], "megacomplex": ["m1"], "kinetic": ["1"]}, + "dataset1": {"megacomplex": ["m1"], "kinetic": ["1"]}, "dataset2": { - "initial_concentration": [], "megacomplex": ["m1"], "kinetic": ["1", "2"], }, - "dataset3": {"initial_concentration": [], "megacomplex": ["m1"], "kinetic": ["2"]}, + "dataset3": {"megacomplex": ["m1"], "kinetic": ["2"]}, }, } - model = DecayModel.from_dict(model_dict) + model = DecayModel(**model_dict) class MultichannelMulticomponentDecay: - wanted_parameters = ParameterGroup.from_dict( + wanted_parameters = Parameters.from_dict( { "k": [0.006, 0.003, 0.0003, 0.03], "loc": [ @@ -152,17 +153,17 @@ class MultichannelMulticomponentDecay: ], } ) - initial_parameters = ParameterGroup.from_dict({"k": [0.006, 0.003, 0.0003, 0.03]}) + initial_parameters = Parameters.from_dict({"k": [0.006, 0.003, 0.0003, 0.03]}) global_axis = np.arange(12820, 15120, 50) model_axis = np.arange(0, 150, 1.5) - sim_model = DecayModel.from_dict( - { + sim_model = DecayModel( + **{ "megacomplex": 
{ - "m1": {"is_index_dependent": False}, + "m1": {"type": "simple-kinetic-test-mc", "is_index_dependent": False}, "m2": { - "type": "global_complex_shaped", + "type": "shaped-spectral-test-mc", "location": ["loc.1", "loc.2", "loc.3", "loc.4"], "delta": ["del.1", "del.2", "del.3", "del.4"], "amplitude": ["amp.1", "amp.2", "amp.3", "amp.4"], @@ -177,9 +178,9 @@ class MultichannelMulticomponentDecay: }, } ) - model = DecayModel.from_dict( - { - "megacomplex": {"m1": {"is_index_dependent": False}}, + model = DecayModel( + **{ + "megacomplex": {"m1": {"type": "simple-kinetic-test-mc", "is_index_dependent": False}}, "dataset": { "dataset1": { "megacomplex": ["m1"], @@ -191,12 +192,12 @@ class MultichannelMulticomponentDecay: class FullModel: - model = DecayModel.from_dict( - { + model = DecayModel( + **{ "megacomplex": { - "m1": {"is_index_dependent": False}, + "m1": {"type": "simple-kinetic-test-mc", "is_index_dependent": False}, "m2": { - "type": "global_complex_shaped", + "type": "shaped-spectral-test-mc", "location": ["loc.1", "loc.2", "loc.3", "loc.4"], "delta": ["del.1", "del.2", "del.3", "del.4"], "amplitude": ["amp.1", "amp.2", "amp.3", "amp.4"], @@ -211,7 +212,7 @@ class FullModel: }, } ) - parameters = ParameterGroup.from_dict( + parameters = Parameters.from_dict( { "k": [0.006, 0.003, 0.0003, 0.03], "loc": [ diff --git a/glotaran/optimization/test/test_constraints.py b/glotaran/optimization/test/test_constraints.py index 5427c0056..87ddb1985 100644 --- a/glotaran/optimization/test/test_constraints.py +++ b/glotaran/optimization/test/test_constraints.py @@ -13,9 +13,9 @@ @pytest.mark.parametrize("link_clp", [True, False]) def test_constraint(index_dependent, link_clp): model = deepcopy(suite.model) - model.dataset_group_models["default"].link_clp = link_clp + model.dataset_groups["default"].link_clp = link_clp model.megacomplex["m1"].is_index_dependent = index_dependent - model.clp_constraints.append(ZeroConstraint.from_dict({"target": "s2"})) + model.clp_constraints.append(ZeroConstraint(**{"target": "s2"})) print("link_clp", link_clp, "index_dependent", index_dependent) dataset = simulate( diff --git a/glotaran/optimization/test/test_data_provider.py b/glotaran/optimization/test/test_data_provider.py index faef6c9f8..961971610 100644 --- a/glotaran/optimization/test/test_data_provider.py +++ b/glotaran/optimization/test/test_data_provider.py @@ -6,7 +6,7 @@ from glotaran.optimization.data_provider import DataProvider from glotaran.optimization.data_provider import DataProviderLinked from glotaran.optimization.test.models import SimpleTestModel -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.project import Scheme @@ -35,9 +35,9 @@ def dataset_two() -> xr.Dataset: @pytest.fixture() def scheme(dataset_one: xr.Dataset, dataset_two: xr.Dataset) -> Scheme: - model = SimpleTestModel.from_dict( - { - "megacomplex": {"m1": {"is_index_dependent": False}}, + model = SimpleTestModel( + **{ + "megacomplex": {"m1": {"type": "simple-kinetic-test-mc", "is_index_dependent": False}}, "dataset": { "dataset1": { "megacomplex": ["m1"], @@ -51,7 +51,7 @@ def scheme(dataset_one: xr.Dataset, dataset_two: xr.Dataset) -> Scheme: print(model.validate()) assert model.valid() - parameters = ParameterGroup.from_list([]) + parameters = Parameters.from_list([]) data = {"dataset1": dataset_one, "dataset2": dataset_two} return Scheme(model, parameters, data, clp_link_tolerance=1) diff --git a/glotaran/optimization/test/test_estimation_provider.py 
b/glotaran/optimization/test/test_estimation_provider.py index 18945921f..00c93d02c 100644 --- a/glotaran/optimization/test/test_estimation_provider.py +++ b/glotaran/optimization/test/test_estimation_provider.py @@ -9,7 +9,7 @@ from glotaran.optimization.matrix_provider import MatrixProviderLinked from glotaran.optimization.matrix_provider import MatrixProviderUnlinked from glotaran.optimization.test.models import SimpleTestModel -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.project import Scheme @@ -38,9 +38,9 @@ def dataset_two() -> xr.Dataset: @pytest.fixture() def scheme(dataset_one: xr.Dataset, dataset_two: xr.Dataset) -> Scheme: - model = SimpleTestModel.from_dict( - { - "megacomplex": {"m1": {"is_index_dependent": False}}, + model = SimpleTestModel( + **{ + "megacomplex": {"m1": {"type": "simple-test-mc", "is_index_dependent": False}}, "dataset": { "dataset1": { "megacomplex": ["m1"], @@ -51,7 +51,7 @@ def scheme(dataset_one: xr.Dataset, dataset_two: xr.Dataset) -> Scheme: }, } ) - parameters = ParameterGroup.from_list([]) + parameters = Parameters.from_list([]) data = {"dataset1": dataset_one, "dataset2": dataset_two} return Scheme(model, parameters, data, clp_link_tolerance=1) diff --git a/glotaran/optimization/test/test_matrix_provider.py b/glotaran/optimization/test/test_matrix_provider.py index 95b9fca6f..d0e4e993a 100644 --- a/glotaran/optimization/test/test_matrix_provider.py +++ b/glotaran/optimization/test/test_matrix_provider.py @@ -7,7 +7,7 @@ from glotaran.optimization.matrix_provider import MatrixProviderLinked from glotaran.optimization.matrix_provider import MatrixProviderUnlinked from glotaran.optimization.test.models import SimpleTestModel -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.project import Scheme @@ -36,9 +36,9 @@ def dataset_two() -> xr.Dataset: @pytest.fixture() def scheme(dataset_one: xr.Dataset, dataset_two: xr.Dataset) -> Scheme: - model = SimpleTestModel.from_dict( - { - "megacomplex": {"m1": {"is_index_dependent": False}}, + model = SimpleTestModel( + **{ + "megacomplex": {"m1": {"type": "simple-test-mc", "is_index_dependent": False}}, "dataset": { "dataset1": { "megacomplex": ["m1"], @@ -49,7 +49,7 @@ def scheme(dataset_one: xr.Dataset, dataset_two: xr.Dataset) -> Scheme: }, } ) - parameters = ParameterGroup.from_list([]) + parameters = Parameters.from_list([]) data = {"dataset1": dataset_one, "dataset2": dataset_two} return Scheme(model, parameters, data, clp_link_tolerance=1) diff --git a/glotaran/optimization/test/test_multiple_goups.py b/glotaran/optimization/test/test_multiple_goups.py index 6f87d57a0..63ce0247b 100644 --- a/glotaran/optimization/test/test_multiple_goups.py +++ b/glotaran/optimization/test/test_multiple_goups.py @@ -2,49 +2,49 @@ from glotaran.optimization.optimize import optimize from glotaran.optimization.test.models import DecayModel -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.project import Scheme from glotaran.simulation import simulate def test_multiple_groups(): - wanted_parameters = ParameterGroup.from_list([101e-4]) - initial_parameters = ParameterGroup.from_list([100e-5]) + wanted_parameters = Parameters.from_list([101e-4]) + initial_parameters = Parameters.from_list([100e-5]) global_axis = np.asarray([1.0]) model_axis = np.arange(0, 150, 1.5) sim_model_dict = { - "megacomplex": {"m1": {"is_index_dependent": False}, "m2": {"type": 
"global_complex"}}, + "megacomplex": { + "m1": {"type": "simple-kinetic-test-mc", "is_index_dependent": False}, + "m2": {"type": "simple-spectral-test-mc"}, + }, "dataset": { "dataset1": { - "initial_concentration": [], "megacomplex": ["m1"], "global_megacomplex": ["m2"], "kinetic": ["1"], } }, } - sim_model = DecayModel.from_dict(sim_model_dict) + sim_model = DecayModel(**sim_model_dict) model_dict = { "dataset_groups": {"g1": {}, "g2": {"residual_function": "non_negative_least_squares"}}, - "megacomplex": {"m1": {"is_index_dependent": False}}, + "megacomplex": {"m1": {"type": "simple-kinetic-test-mc", "is_index_dependent": False}}, "dataset": { "dataset1": { "group": "g1", - "initial_concentration": [], "megacomplex": ["m1"], "kinetic": ["1"], }, "dataset2": { "group": "g2", - "initial_concentration": [], "megacomplex": ["m1"], "kinetic": ["1"], }, }, } - model = DecayModel.from_dict(model_dict) + model = DecayModel(**model_dict) dataset = simulate( sim_model, "dataset1", @@ -62,9 +62,9 @@ def test_multiple_groups(): result = optimize(scheme, raise_exception=True) print(result.optimized_parameters) assert result.success - for label, param in result.optimized_parameters.all(): + for param in result.optimized_parameters.all(): if param.vary: - assert np.allclose(param.value, wanted_parameters.get(label).value, rtol=1e-1) + assert np.allclose(param.value, wanted_parameters.get(param.label).value, rtol=1e-1) for dataset in result.data.values(): assert "weighted_root_mean_square_error" in dataset.attrs diff --git a/glotaran/optimization/test/test_optimization.py b/glotaran/optimization/test/test_optimization.py index 4fe80d983..6b82c44b9 100644 --- a/glotaran/optimization/test/test_optimization.py +++ b/glotaran/optimization/test/test_optimization.py @@ -2,6 +2,8 @@ import pytest import xarray as xr +from glotaran.model.dataset_model import is_dataset_model_index_dependent +from glotaran.model.item import fill_item from glotaran.optimization.optimize import optimize from glotaran.optimization.test.models import SimpleTestModel from glotaran.optimization.test.suites import FullModel @@ -9,7 +11,7 @@ from glotaran.optimization.test.suites import OneCompartmentDecay from glotaran.optimization.test.suites import ThreeDatasetDecay from glotaran.optimization.test.suites import TwoCompartmentDecay -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.project import Scheme from glotaran.simulation import simulate @@ -56,7 +58,9 @@ def test_optimization(suite, is_index_dependent, link_clp, weight, method): print(model.validate(initial_parameters)) # T201 assert model.valid(initial_parameters) assert ( - model.dataset["dataset1"].fill(model, initial_parameters).is_index_dependent() + is_dataset_model_index_dependent( + fill_item(model.dataset["dataset1"], model, initial_parameters) + ) == is_index_dependent ) @@ -74,7 +78,7 @@ def test_optimization(suite, is_index_dependent, link_clp, weight, method): ) print(f"Dataset {i+1}") # T201 print("=============") # T201 - print(dataset) # T201 + print(dataset.data) # T201 if hasattr(suite, "scale"): dataset["data"] /= suite.scale @@ -97,20 +101,22 @@ def test_optimization(suite, is_index_dependent, link_clp, weight, method): optimization_method=method, ) - model.dataset_group_models["default"].link_clp = link_clp + model.dataset_groups["default"].link_clp = link_clp result = optimize(scheme, raise_exception=True) print(result.optimized_parameters) # T201 + print(result.data["dataset1"].fitted_data) # T201 
assert result.success optimized_scheme = result.get_scheme() + assert result.optimized_parameters != initial_parameters assert result.optimized_parameters == optimized_scheme.parameters for dataset in optimized_scheme.data.values(): assert "fitted_data" not in dataset if weight: assert "weight" in dataset - for label, param in result.optimized_parameters.all(): + for param in result.optimized_parameters.all(): if param.vary: - assert np.allclose(param.value, wanted_parameters.get(label).value, rtol=1e-1) + assert np.allclose(param.value, wanted_parameters.get(param.label).value, rtol=1e-1) for i, dataset in enumerate(data.values()): resultdata = result.data[f"dataset{i+1}"] @@ -161,9 +167,9 @@ def test_optimization_full_model(index_dependent): result_data = result.data["dataset1"] assert "fitted_data" in result_data - for label, param in result.optimized_parameters.all(): + for param in result.optimized_parameters.all(): if param.vary: - assert np.allclose(param.value, parameters.get(label).value, rtol=1e-1) + assert np.allclose(param.value, parameters.get(param.label).value, rtol=1e-1) clp = result_data.clp print(clp) # T201 @@ -182,12 +188,8 @@ def test_result_data(model_weight: bool, index_dependent: bool): ).to_dataset(name="data") model_dict = { - "megacomplex": {"m1": {"is_index_dependent": index_dependent}}, - "dataset": { - "dataset1": { - "megacomplex": ["m1"], - }, - }, + "megacomplex": {"m1": {"type": "simple-test-mc", "is_index_dependent": index_dependent}}, + "dataset": {"dataset1": {"megacomplex": ["m1"]}}, } if model_weight: @@ -195,9 +197,9 @@ def test_result_data(model_weight: bool, index_dependent: bool): else: data["weight"] = xr.ones_like(data.data) * 0.5 - model = SimpleTestModel.from_dict(model_dict) + model = SimpleTestModel(**model_dict) assert model.valid() - parameters = ParameterGroup.from_list([1]) + parameters = Parameters.from_list([1]) scheme = Scheme(model, parameters, {"dataset1": data}, maximum_number_function_evaluations=1) result = optimize(scheme, raise_exception=True) diff --git a/glotaran/optimization/test/test_penalties.py b/glotaran/optimization/test/test_penalties.py index ad8e297be..f2bc1e23a 100644 --- a/glotaran/optimization/test/test_penalties.py +++ b/glotaran/optimization/test/test_penalties.py @@ -6,7 +6,7 @@ from glotaran.model import EqualAreaPenalty from glotaran.optimization.optimization_group import OptimizationGroup from glotaran.optimization.test.suites import TwoCompartmentDecay as suite -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.project import Scheme from glotaran.simulation import simulate @@ -15,11 +15,11 @@ @pytest.mark.parametrize("link_clp", [True, False]) def test_penalties(index_dependent, link_clp): model = deepcopy(suite.model) - model.dataset_group_models["default"].link_clp = link_clp + model.dataset_groups["default"].link_clp = link_clp model.megacomplex["m1"].is_index_dependent = index_dependent - model.clp_area_penalties.append( - EqualAreaPenalty.from_dict( - { + model.clp_penalties.append( + EqualAreaPenalty( + **{ "source": "s1", "source_intervals": [(1, 20)], "target": "s2", @@ -29,7 +29,7 @@ def test_penalties(index_dependent, link_clp): } ) ) - parameters = ParameterGroup.from_list([11e-4, 22e-5, 2]) + parameters = Parameters.from_list([11e-4, 22e-5, 2]) global_axis = np.arange(50) diff --git a/glotaran/optimization/test/test_relations.py b/glotaran/optimization/test/test_relations.py index 6dc3be49a..270130077 100644 --- 
a/glotaran/optimization/test/test_relations.py +++ b/glotaran/optimization/test/test_relations.py @@ -2,10 +2,10 @@ import pytest -from glotaran.model import Relation +from glotaran.model import ClpRelation from glotaran.optimization.optimization_group import OptimizationGroup from glotaran.optimization.test.suites import TwoCompartmentDecay as suite -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.project import Scheme from glotaran.simulation import simulate @@ -14,12 +14,10 @@ @pytest.mark.parametrize("link_clp", [True, False]) def test_relations(index_dependent, link_clp): model = deepcopy(suite.model) - model.dataset_group_models["default"].link_clp = link_clp + model.dataset_groups["default"].link_clp = link_clp model.megacomplex["m1"].is_index_dependent = index_dependent - model.clp_relations.append( - Relation.from_dict({"source": "s1", "target": "s2", "parameter": "3"}) - ) - parameters = ParameterGroup.from_list([11e-4, 22e-5, 2]) + model.clp_relations.append(ClpRelation(**{"source": "s1", "target": "s2", "parameter": "3"})) + parameters = Parameters.from_list([11e-4, 22e-5, 2]) print("link_clp", link_clp, "index_dependent", index_dependent) # T201 dataset = simulate( diff --git a/glotaran/parameter/__init__.py b/glotaran/parameter/__init__.py index 9eb44b67a..3e08f5cc8 100644 --- a/glotaran/parameter/__init__.py +++ b/glotaran/parameter/__init__.py @@ -1,4 +1,17 @@ """The glotaran parameter package.""" from glotaran.parameter.parameter import Parameter -from glotaran.parameter.parameter_group import ParameterGroup from glotaran.parameter.parameter_history import ParameterHistory +from glotaran.parameter.parameters import Parameters + + +def __getattr__(attribute_name: str): + from glotaran.deprecation import deprecate_module_attribute + + if attribute_name == "ParameterGroup": + return deprecate_module_attribute( + deprecated_qual_name="glotaran.parameter.ParameterGroup", + new_qual_name="glotaran.parameter.Parameters", + to_be_removed_in_version="0.8.0", + ) + + raise AttributeError(f"module {__name__} has no attribute {attribute_name}") diff --git a/glotaran/parameter/parameter.py b/glotaran/parameter/parameter.py index 9bf3e9e42..34e5ca950 100644 --- a/glotaran/parameter/parameter.py +++ b/glotaran/parameter/parameter.py @@ -4,9 +4,18 @@ import re from typing import TYPE_CHECKING +from typing import Any import asteval import numpy as np +from attr import ib +from attrs import Attribute +from attrs import asdict +from attrs import define +from attrs import evolve +from attrs import fields +from attrs import filters +from attrs import validators try: from numpy._typing._array_like import _SupportsArray @@ -14,27 +23,62 @@ # numpy < 1.23 from numpy.typing._array_like import _SupportsArray +from glotaran.utils.attrs_helper import no_default_vals_in_repr +from glotaran.utils.helpers import nan_or_equal from glotaran.utils.ipython import MarkdownStr from glotaran.utils.sanitize import pretty_format_numerical from glotaran.utils.sanitize import sanitize_parameter_list if TYPE_CHECKING: - from typing import Any - from glotaran.parameter import ParameterGroup + from glotaran.parameter import Parameters -RESERVED_LABELS: list[str] = list(asteval.make_symbol_table().keys()) + ["group", "iteration"] +RESERVED_LABELS: list[str] = list(asteval.make_symbol_table().keys()) + ["parameters", "iteration"] -class Keys: - """Keys for parameter options.""" +OPTION_NAMES_SERIALIZED = { + "expression": "expr", + "maximum": "max", + "minimum": 
"min", + "non_negative": "non-negative", + "standard_error": "standard-error", +} - EXPR = "expr" - MAX = "max" - MIN = "min" - NON_NEG = "non-negative" - STD_ERR = "standard-error" - VARY = "vary" +OPTION_NAMES_DESERIALIZED = {v: k for k, v in OPTION_NAMES_SERIALIZED.items()} + + +def deserialize_options(options: dict[str, Any]) -> dict[str, Any]: + """Replace options keys in serialized format by attribute names. + + Parameters + ---------- + options : dict[str, Any] + The serialized options. + + Returns + ------- + dict[str, Any] + The deserialized options. + + """ + return {OPTION_NAMES_DESERIALIZED.get(k, k): v for k, v in options.items()} + + +def serialize_options(options: dict[str, Any]) -> dict[str, Any]: + """Replace options keys with serialized format by attribute names. + + Parameters + ---------- + options : dict[str, Any] + The options. + + Returns + ------- + dict[str, Any] + The serialized options. + + """ + return {OPTION_NAMES_SERIALIZED.get(k, k): v for k, v in options.items()} PARAMETER_EXPRESSION_REGEX = re.compile(r"\$(?P[\w\d\.]+)((?![\w\d\.]+)|$)") @@ -43,439 +87,175 @@ class Keys: """A regular expression to validate labels.""" -class Parameter(_SupportsArray): - """A parameter for optimization.""" - - def __init__( - self, - label: str = None, - full_label: str = None, - expression: str | None = None, - maximum: float = np.inf, - minimum: float = -np.inf, - non_negative: bool = False, - standard_error: float = np.nan, - value: float = np.nan, - vary: bool = True, - ): - """Optimization Parameter supporting numpy array operations. - - Parameters - ---------- - label : str - The label of the parameter., by default None - full_label : str - The label of the parameter with its path in a parameter group prepended. - , by default None - expression : str | None - Expression to calculate the parameters value from, - e.g. if used in relation to another parameter. , by default None - maximum : float - Upper boundary for the parameter to be varied to., by default np.inf - minimum : float - Lower boundary for the parameter to be varied to., by default -np.inf - non_negative : bool - Whether the parameter should always be bigger than zero., by default False - standard_error: float - The standard error of the parameter. , by default ``np.nan`` - value : float - Value of the parameter, by default np.nan - vary : bool - Whether the parameter should be changed during optimization or not. - , by default True - """ - self.label = label - self.full_label = full_label or "" - self.expression = expression - self.maximum = maximum - self.minimum = minimum - self.non_negative = non_negative - self.standard_error = standard_error - self.value = value - self.vary = vary +def valid_label(parameter: Parameter, attribute: Attribute, label: str): + """Check if a label is a valid label for :class:`Parameter`. - self._transformed_expression: str | None = None + Parameters + ---------- + parameter : Parameter + The :class:`Parameter` instance + attribute : Attribute + The label field. + label : str + The label value. + + Raises + ------ + ValueError + Raise when the label is not valid. + """ + if VALID_LABEL_REGEX.search(label.replace(".", "_")) is not None or label in RESERVED_LABELS: + raise ValueError(f"'{label}' is not a valid parameter label.") - @staticmethod - def create_default_list(label: str) -> list: - """Create a default list for use with :method:`Parameter.from_list_or_value`. - Intended for parameter generation. 
+def set_transformed_expression(parameter: Parameter, attribute: Attribute, expression: str | None): + """Set the transformed expression from an expression. - Parameters - ---------- - label : str - The label of the parameter. + Parameters + ---------- + parameter : Parameter + The :class:`Parameter` instance + attribute : Attribute + The label field. + expression : str | None + The expression value. + """ + if expression: + parameter.vary = False + parameter.transformed_expression = PARAMETER_EXPRESSION_REGEX.sub( + r"parameters.get('\g').value", expression + ) - Returns - ------- - list - The list with default values. - See Also - -------- - :method:`Model.generate_parameters` +@no_default_vals_in_repr +@define +class Parameter(_SupportsArray): + """A parameter for optimization.""" - """ - return [ - label, - 0.0, - { - Keys.EXPR: None, - Keys.MAX: np.inf, - Keys.MIN: -np.inf, - Keys.NON_NEG: False, - Keys.VARY: True, - }, - ] - - @staticmethod - def valid_label(label: str) -> bool: - """Check if a label is a valid label for :class:`Parameter`. + label: str = ib(converter=str, validator=[valid_label]) + value: float = ib( + default=np.nan, + converter=lambda v: float(v) if isinstance(v, int) else v, + validator=[validators.instance_of(float)], + ) + standard_error: float = np.nan + expression: str | None = ib(default=None, validator=[set_transformed_expression]) + maximum: float = ib(default=np.inf, validator=[validators.instance_of((int, float))]) + minimum: float = ib(default=-np.inf, validator=[validators.instance_of((int, float))]) + non_negative: bool = False + vary: bool = ib(default=True) + + transformed_expression: str | None = ib(default=None, init=False, repr=False) - Parameters - ---------- - label : str - The label to validate. + @property + def label_short(self) -> str: + """Get short label. Returns ------- - bool - Whether the label is valid. - + str : + The short label. """ - return VALID_LABEL_REGEX.search(label) is None and label not in RESERVED_LABELS + return self.label.split(".")[-1] @classmethod - def from_list_or_value( + def from_list( cls, - value: int | float | list, + values: list[Any], + *, default_options: dict[str, Any] | None = None, - label: str = None, ) -> Parameter: - """Create a parameter from a list or numeric value. + """Create a parameter from a list. Parameters ---------- - value : int | float | list - The list or numeric value. - default_options : dict[str, Any]|None + values : list[Any] + The list of parameter definitions. + default_options : dict[str, Any] | None A dictionary of default options. - label : str - The label of the parameter. Returns ------- Parameter The created :class:`Parameter`. 
""" - param = cls(label=label) options = None - if not isinstance(value, list): - param.value = value - - else: - values = sanitize_parameter_list(value) - param.label = _retrieve_item_from_list_by_type(values, str, label) - param.value = float(_retrieve_item_from_list_by_type(values, (int, float), np.nan)) - options = _retrieve_item_from_list_by_type(values, dict, None) + values = sanitize_parameter_list(values.copy()) + param = { + "label": _retrieve_item_from_list_by_type(values, str, ""), + "value": _retrieve_item_from_list_by_type(values, (int, float), np.nan), + } + options = _retrieve_item_from_list_by_type(values, dict, {}) if default_options: - param._set_options_from_dict(default_options) - - if options: - param._set_options_from_dict(options) - return param + param |= deserialize_options(default_options) + param |= deserialize_options(options) - @classmethod - def from_dict(cls, parameter_dict: dict[str, Any]) -> Parameter: - """Create a :class:`Parameter` from a dictionary. + return cls(**param) - Expects a dictionary created by :method:`Parameter.as_dict`. - - Parameters - ---------- - parameter_dict : dict[str, Any] - The source dictionary. + def copy(self) -> Parameter: + """Create a copy of the :class:`Parameter`. Returns ------- - Parameter - The created :class:`Parameter` + Parameter : + A copy of the :class:`Parameter`. """ - parameter_dict = {k.replace("-", "_"): v for k, v in parameter_dict.items()} - parameter_dict["full_label"] = parameter_dict["label"] - parameter_dict["label"] = parameter_dict["label"].split(".")[-1] - return cls(**parameter_dict) + return evolve(self) - def as_dict(self, as_optimized: bool = True) -> dict[str, Any]: - """Create a dictionary containing the parameter properties. - - Note: - ----- - Intended for internal use. - - Parameters - ---------- - as_optimized : bool - Whether to include properties which are the result of optimization. + def as_dict(self) -> dict[str, Any]: + """Get the parameter as a dictionary. Returns ------- dict[str, Any] - The created dictionary. + The parameter as dictionary. """ - parameter_dict = { - "label": self.full_label, - "value": self.value, - "expression": self.expression, - "minimum": self.minimum, - "maximum": self.maximum, - "non-negative": self.non_negative, - "vary": self.vary, - } - if as_optimized: - parameter_dict["standard-error"] = self.standard_error - return parameter_dict + return asdict(self, filter=filters.exclude(fields(Parameter).transformed_expression)) - def set_from_group(self, group: ParameterGroup): - """Set values of the parameter to the values of the corresponding parameter in the group. + def _deep_equals(self, other: Parameter) -> bool: + """Compare all attributes for equality not only ``value`` like ``__eq__`` does. - Notes - ----- - For internal use. + This is used by ``Parameters`` to check for equality. Parameters ---------- - group : ParameterGroup - The :class:`glotaran.parameter.ParameterGroup`. - """ - p = group.get(self.full_label) - self.expression = p.expression - self.maximum = p.maximum - self.minimum = p.minimum - self.non_negative = p.non_negative - self.standard_error = p.standard_error - self.value = p.value - self.vary = p.vary - - def _set_options_from_dict(self, options: dict[str, Any]): - """Set the parameter's options from a dictionary. - - Parameters - ---------- - options : dict[str, Any] - A dictionary containing parameter options. 
- """ - if Keys.EXPR in options: - self.expression = options[Keys.EXPR] - if Keys.NON_NEG in options: - self.non_negative = options[Keys.NON_NEG] - if Keys.MAX in options: - self.maximum = options[Keys.MAX] - if Keys.MIN in options: - self.minimum = options[Keys.MIN] - if Keys.VARY in options: - self.vary = options[Keys.VARY] - if Keys.STD_ERR in options: - self.standard_error = options[Keys.STD_ERR] - - @property - def label(self) -> str | None: - """Label of the parameter. - - Returns - ------- - str - The label. - """ - return self._label - - @label.setter - def label(self, label: str | None): - # ensure that label is str | None even if an int is passed - label = None if label is None else str(label) - if label is not None and not Parameter.valid_label(label): - raise ValueError(f"'{label}' is not a valid group label.") - self._label = label - - @property - def full_label(self) -> str: - """Label of the parameter with its path in a parameter group prepended. - - Returns - ------- - str - The full label. - """ - return self._full_label - - @full_label.setter - def full_label(self, full_label: str): - self._full_label = full_label - - @property - def non_negative(self) -> bool: - r"""Indicate if the parameter is non-negative. - - If true, the parameter will be transformed with :math:`p' = \log{p}` and - :math:`p = \exp{p'}`. - - Notes - ----- - Always `False` if `expression` is not `None`. + other: Parameter + Other parameter to compare against. Returns ------- bool - Whether the parameter is non-negative. - """ - return self._non_negative if self.expression is None else False - - @non_negative.setter - def non_negative(self, non_negative: bool): - self._non_negative = non_negative - - @property - def vary(self) -> bool: - """Indicate if the parameter should be optimized. - - Notes - ----- - Always `False` if `expression` is not `None`. - - Returns - ------- - bool - Whether the parameter should be optimized. - """ - return self._vary if self.expression is None else False - - @vary.setter - def vary(self, vary: bool): - self._vary = vary - - @property - def maximum(self) -> float: - """Upper bound of the parameter. - - Returns - ------- - float - The upper bound of the parameter. - """ - return self._maximum - - @maximum.setter - def maximum(self, maximum: int | float): - if not isinstance(maximum, float): - try: - maximum = float(maximum) - except Exception: - raise TypeError( - "Parameter maximum must be numeric." - + f"'{self.full_label}' has maximum '{maximum}' of type '{type(maximum)}'" - ) - - self._maximum = maximum - - @property - def minimum(self) -> float: - """Lower bound of the parameter. - - Returns - ------- - float - - The lower bound of the parameter. + Whether or not all attributes are equal. """ - return self._minimum - - @minimum.setter - def minimum(self, minimum: int | float): - if not isinstance(minimum, float): - try: - minimum = float(minimum) - except Exception: - raise TypeError( - "Parameter minimum must be numeric." - + f"'{self.full_label}' has minimum '{minimum}' of type '{type(minimum)}'" - ) - - self._minimum = minimum + return all( + nan_or_equal(self_val, other_val) + for self_val, other_val in zip(self.as_dict().values(), other.as_dict().values()) + ) - @property - def expression(self) -> str | None: - """Expression to calculate the parameters value from. + def as_list(self, label_short: bool = False) -> list[str | float | dict[str, Any]]: + """Get the parameter as a dictionary. - This can used to set a relation to another parameter. 
+ Parameters + ---------- + label_short : bool + If true, the label will be replaced by the shortened label. Returns ------- - str | None - The expression. + dict[str, Any] + The parameter as dictionary. """ - return self._expression + options = self.as_dict() - @expression.setter - def expression(self, expression: str | None): - self._expression = expression - self._transformed_expression = None + label = options.pop("label") + value = options.pop("value") - @property - def transformed_expression(self) -> str | None: - """Expression of the parameter transformed for evaluation within a `ParameterGroup`. + if label_short: + label = self.label_short - Returns - ------- - str | None - The transformed expression. - """ - if self.expression is not None and self._transformed_expression is None: - self._transformed_expression = PARAMETER_EXPRESSION_REGEX.sub( - r"group.get('\g').value", self.expression - ) - return self._transformed_expression - - @property - def standard_error(self) -> float: - """Standard error of the optimized parameter. - - Returns - ------- - float - The standard error of the parameter. - """ # noqa: D401 - return self._stderr - - @standard_error.setter - def standard_error(self, standard_error: float): - self._stderr = standard_error - - @property - def value(self) -> float: - """Value of the parameter. - - Returns - ------- - float - The value of the parameter. - """ - return self._value - - @value.setter - def value(self, value: int | float): - if not isinstance(value, float) and value is not np.nan: - try: - value = float(value) - except Exception: - raise TypeError( - "Parameter value must be numeric." - + f"'{self.full_label}' has value '{value}' of type '{type(value)}'" - ) - - self._value = value + return [label, value, serialize_options(options)] def get_value_and_bounds_for_optimization(self) -> tuple[float, float, float]: """Get the parameter value and bounds with expression and non-negative constraints applied. @@ -508,16 +288,16 @@ def set_value_from_optimization(self, value: float): def markdown( self, - all_parameters: ParameterGroup | None = None, - initial_parameters: ParameterGroup | None = None, + all_parameters: Parameters | None = None, + initial_parameters: Parameters | None = None, ) -> MarkdownStr: """Get a markdown representation of the parameter. Parameters ---------- - all_parameters : ParameterGroup | None + all_parameters : Parameters | None A parameter group containing the whole parameter set (used for expression lookup). - initial_parameters : ParameterGroup | None + initial_parameters : Parameters | None The initial parameter. Returns @@ -525,9 +305,9 @@ def markdown( MarkdownStr The parameter as markdown string. 
""" - md = f"{self.full_label}" + md = f"{self.label}" - parameter = self if all_parameters is None else all_parameters.get(self.full_label) + parameter = self if all_parameters is None else all_parameters.get(self.label) value = f"{parameter.value:.2e}" if parameter.vary: if parameter.standard_error is not np.nan: @@ -535,7 +315,7 @@ def markdown( value += f"Β±{parameter.standard_error:.2e}, t-value: {t_value}" if initial_parameters is not None: - initial_value = initial_parameters.get(parameter.full_label).value + initial_value = initial_parameters.get(parameter.label).value value += f", initial: {initial_value:.2e}" md += f"({value})" elif parameter.expression is not None: @@ -554,44 +334,9 @@ def markdown( return MarkdownStr(md) - def __getstate__(self): - """Get state for pickle.""" - return ( - self.label, - self.full_label, - self.expression, - self.maximum, - self.minimum, - self.non_negative, - self.standard_error, - self.value, - self.vary, - ) - - def __setstate__(self, state): - """Set state from pickle.""" - ( - self.label, - self.full_label, - self.expression, - self.maximum, - self.minimum, - self.non_negative, - self.standard_error, - self.value, - self.vary, - ) = state - - def __repr__(self): - """Representation used by repl and tracebacks.""" - return ( - f"{type(self).__name__}(label={self.label!r}, value={self.value!r}," - f" expression={self.expression!r}, vary={self.vary!r})" - ) - def __array__(self): """array""" # noqa: D400, D403 - return np.array(float(self._value), dtype=float) + return np.array(self.value, dtype=float) def __str__(self) -> str: """Representation used by print and str.""" @@ -603,115 +348,115 @@ def __str__(self) -> str: def __abs__(self): """abs""" # noqa: D400, D403 - return abs(self._value) + return abs(self.value) def __neg__(self): """neg""" # noqa: D400, D403 - return -self._value + return -self.value def __pos__(self): """positive""" # noqa: D400, D403 - return +self._value + return +self.value def __int__(self): """int""" # noqa: D400, D403 - return int(self._value) + return int(self.value) def __float__(self): """float""" # noqa: D400, D403 - return float(self._value) + return float(self.value) def __trunc__(self): """trunc""" # noqa: D400, D403 - return self._value.__trunc__() + return self.value.__trunc__() def __add__(self, other): """+""" # noqa: D400 - return self._value + other + return self.value + other def __sub__(self, other): """-""" # noqa: D400 - return self._value - other + return self.value - other def __truediv__(self, other): """/""" # noqa: D400 - return self._value / other + return self.value / other def __floordiv__(self, other): """//""" # noqa: D400 - return self._value // other + return self.value // other def __divmod__(self, other): """divmod""" # noqa: D400, D403 - return divmod(self._value, other) + return divmod(self.value, other) def __mod__(self, other): """%""" # noqa: D400 - return self._value % other + return self.value % other def __mul__(self, other): """*""" # noqa: D400 - return self._value * other + return self.value * other def __pow__(self, other): """**""" # noqa: D400 - return self._value**other + return self.value**other def __gt__(self, other): """>""" # noqa: D400 - return self._value > other + return self.value > other def __ge__(self, other): """>=""" # noqa: D400 - return self._value >= other + return self.value >= other def __le__(self, other): """<=""" # noqa: D400 - return self._value <= other + return self.value <= other def __lt__(self, other): """<""" # noqa: D400 - return 
self._value < other + return self.value < other def __eq__(self, other): """==""" # noqa: D400 - return self._value == other + return self.value == other def __ne__(self, other): """!=""" # noqa: D400 - return self._value != other + return self.value != other def __radd__(self, other): """+ (right)""" # noqa: D400 - return other + self._value + return other + self.value def __rtruediv__(self, other): """/ (right)""" # noqa: D400 - return other / self._value + return other / self.value def __rdivmod__(self, other): """divmod (right)""" # noqa: D400, D403 - return divmod(other, self._value) + return divmod(other, self.value) def __rfloordiv__(self, other): """// (right)""" # noqa: D400 - return other // self._value + return other // self.value def __rmod__(self, other): """% (right)""" # noqa: D400 - return other % self._value + return other % self.value def __rmul__(self, other): """* (right)""" # noqa: D400 - return other * self._value + return other * self.value def __rpow__(self, other): """** (right)""" # noqa: D400 - return other**self._value + return other**self.value def __rsub__(self, other): """- (right)""" # noqa: D400 - return other - self._value + return other - self.value def _log_value(value: float) -> float: diff --git a/glotaran/parameter/parameter_group.py b/glotaran/parameter/parameter_group.py deleted file mode 100644 index 03dbd4ce9..000000000 --- a/glotaran/parameter/parameter_group.py +++ /dev/null @@ -1,700 +0,0 @@ -"""The parameter group class.""" - -from __future__ import annotations - -import contextlib -from copy import copy -from textwrap import indent -from typing import TYPE_CHECKING -from typing import Any -from typing import Generator - -import asteval -import numpy as np -import pandas as pd -from tabulate import tabulate - -from glotaran.deprecation import deprecate -from glotaran.io import load_parameters -from glotaran.io import save_parameters -from glotaran.parameter.parameter import Parameter -from glotaran.utils.ipython import MarkdownStr -from glotaran.utils.sanitize import pretty_format_numerical - -if TYPE_CHECKING: - from glotaran.parameter.parameter_history import ParameterHistory - - -class ParameterNotFoundException(Exception): - """Raised when a Parameter is not found in the Group.""" - - def __init__(self, path, label): # noqa: D107 - super().__init__(f"Cannot find parameter {'.'.join(path+[label])!r}") - - -class ParameterGroup(dict): - """Represents are group of parameters. - - Can contain other groups, creating a tree-like hierarchy. - """ - - loader = load_parameters - - def __init__(self, label: str = None, root_group: ParameterGroup = None): - """Initialize a :class:`ParameterGroup` instance with ``label``. - - Parameters - ---------- - label : str - The label of the group. - root_group : ParameterGroup - The root group - - Raises - ------ - ValueError - Raised if the an invalid label is given. 
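The arithmetic dunder hunks earlier in this file swap the removed private `_value` for the public attrs field `value`; the observable behavior is unchanged, so a `Parameter` still acts like a plain float in numeric and numpy contexts. A quick illustrative sketch:

```python
import numpy as np

from glotaran.parameter import Parameter

param = Parameter(label="rate", value=2.0)

assert param + 1 == 3.0        # __add__ delegates to param.value
assert param**2 == 4.0         # __pow__
assert float(param) == 2.0     # __float__
assert np.array(param) == 2.0  # __array__ wraps the value in a float array
```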
- """ - if label is not None and not Parameter.valid_label(label): - raise ValueError(f"'{label}' is not a valid group label.") - self._label = label - self._parameters: dict[str, Parameter] = {} - self._root_group = root_group - self._evaluator = ( - asteval.Interpreter(symtable=asteval.make_symbol_table(group=self)) - if root_group is None - else None - ) - self.source_path = "parameters.csv" - super().__init__() - - @classmethod - def from_dict( - cls, - parameter_dict: dict[str, dict[str, Any] | list[float | list[Any]]], - label: str = None, - root_group: ParameterGroup = None, - ) -> ParameterGroup: - """Create a :class:`ParameterGroup` from a dictionary. - - Parameters - ---------- - parameter_dict : dict[str, dict | list] - A parameter dictionary containing parameters. - label : str - The label of the group. - root_group : ParameterGroup - The root group - - Returns - ------- - ParameterGroup - The created :class:`ParameterGroup` - """ - root = cls(label=label, root_group=root_group) - for label, item in parameter_dict.items(): - label = str(label) - if isinstance(item, dict): - root.add_group(cls.from_dict(item, label=label, root_group=root)) - if isinstance(item, list): - root.add_group(cls.from_list(item, label=label, root_group=root)) - if root_group is None: - root.update_parameter_expression() - return root - - @classmethod - def from_list( - cls, - parameter_list: list[float | list[Any]], - label: str = None, - root_group: ParameterGroup = None, - ) -> ParameterGroup: - """Create a :class:`ParameterGroup` from a list. - - Parameters - ---------- - parameter_list : list[float | list[Any]] - A parameter list containing parameters - label : str - The label of the group. - root_group : ParameterGroup - The root group - - Returns - ------- - ParameterGroup - The created :class:`ParameterGroup`. - """ - root = cls(label=label, root_group=root_group) - - defaults: dict[str, Any] | None = next( - (item for item in parameter_list if isinstance(item, dict)), None # type:ignore[misc] - ) - - for i, item in enumerate(parameter_list): - if isinstance(item, (str, int, float)): - with contextlib.suppress(ValueError): - item = float(item) - if isinstance(item, (float, int, list)): - root.add_parameter( - Parameter.from_list_or_value(item, label=str(i + 1), default_options=defaults) - ) - if root_group is None: - root.update_parameter_expression() - return root - - @classmethod - def from_parameter_dict_list(cls, parameter_dict_list: list[dict[str, Any]]) -> ParameterGroup: - """Create a :class:`ParameterGroup` from a list of parameter dictionaries. - - Parameters - ---------- - parameter_dict_list : list[dict[str, Any]] - A list of parameter dictionaries. - - Returns - ------- - ParameterGroup - The created :class:`ParameterGroup`. - """ - parameter_group = cls() - for parameter_dict in parameter_dict_list: - group = parameter_group.get_group_for_parameter_by_label( - parameter_dict["label"], create_if_not_exist=True - ) - group.add_parameter(Parameter.from_dict(parameter_dict)) - parameter_group.update_parameter_expression() - return parameter_group - - @classmethod - def from_dataframe(cls, df: pd.DataFrame, source: str = "DataFrame") -> ParameterGroup: - """Create a :class:`ParameterGroup` from a :class:`pandas.DataFrame`. - - Parameters - ---------- - df : pd.DataFrame - The source data frame. - source : str - Optional name of the source file, used for error messages. - - Returns - ------- - ParameterGroup - The created parameter group. 
- - Raises - ------ - ValueError - Raised if the columns 'label' or 'value' doesn't exist. Also raised if the columns - 'minimum', 'maximum' or 'values' contain non numeric values or if the columns - 'non-negative' or 'vary' are no boolean. - """ - for column_name in ["label", "value"]: - if column_name not in df: - raise ValueError(f"Missing column '{column_name}' in '{source}'") - - for column_name in ["minimum", "maximum", "value"]: - if column_name in df and any(not np.isreal(v) for v in df[column_name]): - raise ValueError(f"Column '{column_name}' in '{source}' has non numeric values") - - for column_name in ["non-negative", "vary"]: - if column_name in df and any(not isinstance(v, bool) for v in df[column_name]): - raise ValueError(f"Column '{column_name}' in '{source}' has non boolean values") - - # clean NaN if expressions - if "expression" in df: - expressions = df["expression"].to_list() - df["expression"] = [expr if isinstance(expr, str) else None for expr in expressions] - return cls.from_parameter_dict_list(df.to_dict(orient="records")) - - @property - def label(self) -> str | None: - """Label of the group. - - Returns - ------- - str - The label of the group. - """ - return self._label - - @property - def root_group(self) -> ParameterGroup | None: - """Root of the group. - - Returns - ------- - ParameterGroup - The root group. - """ - return self._root_group - - def to_parameter_dict_list(self, as_optimized: bool = True) -> list[dict[str, Any]]: - """Create list of parameter dictionaries from the group. - - Parameters - ---------- - as_optimized : bool - Whether to include properties which are the result of optimization. - - Returns - ------- - list[dict[str, Any]] - Alist of parameter dictionaries. - """ - return [p[1].as_dict(as_optimized=as_optimized) for p in self.all()] - - def to_dataframe(self, as_optimized: bool = True) -> pd.DataFrame: - """Create a pandas data frame from the group. - - Parameters - ---------- - as_optimized : bool - Whether to include properties which are the result of optimization. - - Returns - ------- - pd.DataFrame - The created data frame. - """ - return pd.DataFrame(self.to_parameter_dict_list(as_optimized=as_optimized)) - - def get_group_for_parameter_by_label( - self, parameter_label: str, create_if_not_exist: bool = False - ) -> ParameterGroup: - """Get the group for a parameter by it's label. - - Parameters - ---------- - parameter_label : str - The parameter label. - create_if_not_exist : bool - Create the parameter group if not existent. - - Returns - ------- - ParameterGroup - The group of the parameter. - - Raises - ------ - KeyError - Raised if the group does not exist and `create_if_not_exist` is `False`. - """ - path = parameter_label.split(".") - group = self - while len(path) > 1: - group_label = path.pop(0) - if group_label not in group: - if create_if_not_exist: - group.add_group(ParameterGroup(label=group_label, root_group=group)) - else: - raise KeyError(f"Subgroup '{group_label}' does not exist.") - group = group[group_label] - return group - - @deprecate( - deprecated_qual_name_usage=( - "glotaran.parameter.ParameterGroup.to_csv(file_name=)" - ), - new_qual_name_usage=( - "glotaran.io.save_parameters(parameters, " - 'file_name=, format_name="csv")' - ), - to_be_removed_in_version="0.7.0", - importable_indices=(2, 1), - ) - def to_csv(self, filename: str, delimiter: str = ",") -> None: - """Save a :class:`ParameterGroup` to a CSV file. 
- - Warning - ------- - Deprecated use - ``glotaran.io.save_parameters(parameters, file_name=, format_name="csv")`` - instead. - - Parameters - ---------- - filename : str - File to write the parameter specs to. - delimiter : str - Character to separate columns., by default "," - """ - save_parameters( - self, - file_name=filename, - allow_overwrite=True, - sep=delimiter, - replace_infinfinity=False, - ) - - def add_parameter(self, parameter: Parameter | list[Parameter]): - """Add a :class:`Parameter` to the group. - - Parameters - ---------- - parameter : Parameter | list[Parameter] - The parameter to add. - - Raises - ------ - TypeError - If ``parameter`` or any item of it is not an instance of :class:`Parameter`. - """ - if not isinstance(parameter, list): - parameter = [parameter] - if any(not isinstance(p, Parameter) for p in parameter): - raise TypeError("Parameter must be instance of glotaran.parameter.Parameter") - for p in parameter: - p.index = len(self._parameters) + 1 - if p.label is None: - p.label = f"{p.index}" - p.full_label = f"{self.label}.{p.label}" if self.label else p.label - self._parameters[p.label] = p - - def add_group(self, group: ParameterGroup): - """Add a :class:`ParameterGroup` to the group. - - Parameters - ---------- - group : ParameterGroup - The group to add. - - Raises - ------ - TypeError - Raised if the group is not an instance of :class:`ParameterGroup`. - """ - if not isinstance(group, ParameterGroup): - raise TypeError("Group must be glotaran.parameter.ParameterGroup") - self[group.label] = group - - def get_nr_roots(self) -> int: - """Return the number of roots of the group. - - Returns - ------- - int - The number of roots. - """ - n = 0 - root = self.root_group - while root is not None: - n += 1 - root = root.root_group - return n - - def groups(self) -> Generator[ParameterGroup, None, None]: - """Return a generator over all groups and their subgroups. - - Yields - ------ - ParameterGroup - A subgroup of :class:`ParameterGroup`. - """ - for group in self: - yield from group.groups() - - def has(self, label: str) -> bool: - """Check if a parameter with the given label is in the group or in a subgroup. - - Parameters - ---------- - label : str - The label of the parameter, with its path in a :class:`ParameterGroup` prepended. - - Returns - ------- - bool - Whether a parameter with the given label exists in the group. - """ - try: - self.get(label) - return True - except Exception: - return False - - def get(self, label: str) -> Parameter: # type:ignore[override] - """Get a :class:`Parameter` by its label. - - Parameters - ---------- - label : str - The label of the parameter, with its path in a :class:`ParameterGroup` prepended. - - Returns - ------- - Parameter - The parameter. - - Raises - ------ - ParameterNotFoundException - Raised if no parameter with the given label exists. - """ - # sometimes the spec parser delivers the labels as int - full_label = str(label) # sourcery skip - - path = full_label.split(".") - label = path.pop() - - # TODO: audit this code - group = self - for element in path: - try: - group = group[element] - except KeyError as e: - raise ParameterNotFoundException(path, label) from e - try: - parameter = group._parameters[label] - parameter.full_label = full_label - return parameter - except KeyError as e: - raise ParameterNotFoundException(path, label) from e - - def copy(self) -> ParameterGroup: - """Create a copy of the :class:`ParameterGroup`. 
- - Returns - ------- - ParameterGroup : - A copy of the :class:`ParameterGroup`. - - """ - root = ParameterGroup(label=self.label, root_group=self.root_group) - - for label, parameter in self._parameters.items(): - root._parameters[label] = copy(parameter) - - for label, group in self.items(): - root[label] = group.copy() - - return root - - def all( - self, root: str | None = None, separator: str = "." - ) -> Generator[tuple[str, Parameter], None, None]: - """Iterate over all parameter in the group and it's subgroups together with their labels. - - Parameters - ---------- - root : str - The label of the root group - separator : str - The separator for the parameter labels. - - Yields - ------ - tuple[str, Parameter] - A tuple containing the full label of the parameter and the parameter itself. - """ - root = f"{root}{self.label}{separator}" if root is not None else "" - for label, p in self._parameters.items(): - p.full_label = f"{root}{label}" - yield (f"{root}{label}", p) - for _, l in self.items(): - yield from l.all(root=root, separator=separator) - - def get_label_value_and_bounds_arrays( - self, exclude_non_vary: bool = False - ) -> tuple[list[str], np.ndarray, np.ndarray, np.ndarray]: - """Return a arrays of all parameter labels, values and bounds. - - Parameters - ---------- - exclude_non_vary: bool - If true, parameters with `vary=False` are excluded. - - Returns - ------- - tuple[list[str], np.ndarray, np.ndarray, np.ndarray] - A tuple containing a list of parameter labels and - an array of the values, lower and upper bounds. - """ - self.update_parameter_expression() - - labels = [] - values = [] - lower_bounds = [] - upper_bounds = [] - - for label, parameter in self.all(): - if not exclude_non_vary or parameter.vary: - labels.append(label) - value, minimum, maximum = parameter.get_value_and_bounds_for_optimization() - values.append(value) - lower_bounds.append(minimum) - upper_bounds.append(maximum) - - return labels, np.asarray(values), np.asarray(lower_bounds), np.asarray(upper_bounds) - - def set_from_label_and_value_arrays(self, labels: list[str], values: np.ndarray): - """Update the parameter values from a list of labels and values. - - Parameters - ---------- - labels : list[str] - A list of parameter labels. - values : np.ndarray - An array of parameter values. - - Raises - ------ - ValueError - Raised if the size of the labels does not match the stize of values. - """ - if len(labels) != len(values): - raise ValueError( - f"Length of labels({len(labels)}) not equal to length of values({len(values)})." - ) - - for label, value in zip(labels, values): - self.get(label).set_value_from_optimization(value) - - self.update_parameter_expression() - - def set_from_history(self, history: ParameterHistory, index: int): - """Update the :class:`ParameterGroup` with values from a parameter history. - - Parameters - ---------- - history : ParameterHistory - The parameter history. - index : int - The history index. - """ - self.set_from_label_and_value_arrays( - # Omit 0th element with `iteration` label - history.parameter_labels[1:], - history.get_parameters(index)[1:], - ) - - def update_parameter_expression(self): - """Update all parameters which have an expression. - - Raises - ------ - ValueError - Raised if an expression evaluates to a non-numeric value. 
- """ - for label, parameter in self.all(): - if parameter.expression is not None: - value = self._evaluator(parameter.transformed_expression) - if not isinstance(value, (int, float)): - raise ValueError( - f"Expression '{parameter.expression}' of parameter '{label}' evaluates to " - f"non numeric value '{value}'." - ) - parameter.value = value - - @property - def missing_parameter_value_labels(self) -> list[str]: - """List of full labels where the value is a NaN. - - This property is used to validate that all parameters have starting values. - - Returns - ------- - str - List full labels with missing value. - """ - parameter_df = self.to_dataframe(as_optimized=False) - parameter_nan_value_mask = parameter_df["value"].isna() - return parameter_df[parameter_nan_value_mask]["label"].to_list() - - def markdown(self, float_format: str = ".3e") -> MarkdownStr: - """Format the :class:`ParameterGroup` as markdown string. - - This is done by recursing the nested :class:`ParameterGroup` tree. - - Parameters - ---------- - float_format: str - Format string for floating point numbers, by default ".3e" - - Returns - ------- - MarkdownStr : - The markdown representation as string. - """ - node_indentation = " " * self.get_nr_roots() - return_string = "" - table_header = [ - "_Label_", - "_Value_", - "_Standard Error_", - "_t-value_", - "_Minimum_", - "_Maximum_", - "_Vary_", - "_Non-Negative_", - "_Expression_", - ] - if self.label is not None: - return_string += f"{node_indentation}* __{self.label}__:\n" - if len(self._parameters): - parameter_rows = [ - [ - parameter.label, - parameter.value, - parameter.standard_error, - repr(pretty_format_numerical(parameter.value / parameter.standard_error)), - parameter.minimum, - parameter.maximum, - parameter.vary, - parameter.non_negative, - f"`{parameter.expression}`", - ] - for _, parameter in self._parameters.items() - ] - parameter_table = indent( - tabulate( - parameter_rows, - headers=table_header, - tablefmt="github", - missingval="None", - floatfmt=float_format, - ), - f" {node_indentation}", - ) - return_string += f"\n{parameter_table}\n\n" - for _, child_group in sorted(self.items()): - return_string += f"{child_group.markdown(float_format=float_format)}" - return MarkdownStr(return_string.replace("'", " ")) - - def _repr_markdown_(self) -> str: - """Create a markdown representation. - - Special method used by ``ipython`` to render markdown. - - Returns - ------- - str : - The markdown representation as string. - """ - return str(self.markdown()) - - def __repr__(self) -> str: - """Representation used by repl and tracebacks. - - Returns - ------- - str : - A string representation of the :class:`ParameterGroup`. 
- """ - parameter_short_notations = [ - [str(parameter.label), parameter.value] for parameter in self._parameters.values() - ] - if self.label is None: - if len(self._parameters) == 0: - return f"{type(self).__name__}.from_dict({super().__repr__()})" - else: - return f"{type(self).__name__}.from_list({parameter_short_notations})" - if len(self._parameters): - return parameter_short_notations.__repr__() - else: - return super().__repr__() - - def __str__(self) -> str: - """Representation used by print and str.""" - return str(self.markdown()) diff --git a/glotaran/parameter/parameter_history.py b/glotaran/parameter/parameter_history.py index b6ce80bb2..81c60d272 100644 --- a/glotaran/parameter/parameter_history.py +++ b/glotaran/parameter/parameter_history.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd -from glotaran.parameter.parameter_group import ParameterGroup +from glotaran.parameter.parameters import Parameters if TYPE_CHECKING: from os import PathLike @@ -112,7 +112,7 @@ def to_dataframe(self) -> pd.DataFrame: return pd.DataFrame(self._parameters, columns=self.parameter_labels) def to_csv(self, file_name: str | PathLike[str], delimiter: str = ","): - """Write a :class:`ParameterGroup` to a CSV file. + """Write a :class:`ParameterHistory` to a CSV file. Parameters ---------- @@ -124,12 +124,12 @@ def to_csv(self, file_name: str | PathLike[str], delimiter: str = ","): self.source_path = Path(file_name).as_posix() self.to_dataframe().to_csv(file_name, sep=delimiter, index=False) - def append(self, parameter_group: ParameterGroup, current_iteration: int = 0): - """Append a :class:`ParameterGroup` to the history. + def append(self, parameters: Parameters, current_iteration: int = 0): + """Append :class:`Parameters` to the history. Parameters ---------- - parameter_group : ParameterGroup + parameters : Parameters The group to append. current_iteration: int Current iteration of the optimizer. @@ -137,21 +137,19 @@ def append(self, parameter_group: ParameterGroup, current_iteration: int = 0): Raises ------ ValueError - Raised if the parameter labels of the group differs from previous groups. + Raised if the parameter labels differs from previous. """ ( parameter_labels, parameter_values, _, _, - ) = parameter_group.get_label_value_and_bounds_arrays() + ) = parameters.get_label_value_and_bounds_arrays() parameter_labels = ["iteration", *parameter_labels] if len(self._parameter_labels) == 0: self._parameter_labels = parameter_labels if parameter_labels != self.parameter_labels: - raise ValueError( - "Cannot append parameter group. Parameter labels do not match existing." - ) + raise ValueError("Cannot append parameters. 
Parameter labels do not match existing.") self._parameters.append(np.array([current_iteration, *parameter_values])) diff --git a/glotaran/parameter/parameters.py b/glotaran/parameter/parameters.py new file mode 100644 index 000000000..b24e95c4b --- /dev/null +++ b/glotaran/parameter/parameters.py @@ -0,0 +1,550 @@ +"""The parameters class.""" +from __future__ import annotations + +from textwrap import indent +from typing import TYPE_CHECKING +from typing import Any +from typing import Generator + +import asteval +import numpy as np +import pandas as pd +from tabulate import tabulate + +from glotaran.io import load_parameters +from glotaran.parameter.parameter import Parameter +from glotaran.utils.ipython import MarkdownStr +from glotaran.utils.sanitize import pretty_format_numerical + +if TYPE_CHECKING: + from glotaran.parameter.parameter_history import ParameterHistory + + +class ParameterNotFoundException(Exception): + """Raised when a Parameter is not found.""" + + def __init__(self, label: str): # noqa: D107 + super().__init__(f"Cannot find parameter {label}") + + +class Parameters: + """A container for :class:`Parameter`.""" + + loader = load_parameters + + def __init__(self, parameters: dict[str, Parameter]): + """Create :class:`Parameters`. + + Parameters + ---------- + parameters : dict[str, Parameter] + A parameter list containing parameters + + Returns + ------- + 'Parameters' + The created :class:`Parameters`. + """ + self._parameters: dict[str, Parameter] = parameters + self._evaluator = asteval.Interpreter(symtable=asteval.make_symbol_table(parameters=self)) + self.source_path = "parameters.csv" + self.update_parameter_expression() + + @classmethod + def from_list( + cls, parameter_list: list[float | int | str | list[Any] | dict[str, Any]] + ) -> Parameters: + """Create :class:`Parameters` from a list. + + Parameters + ---------- + parameter_list : list[float | list[Any]] + A parameter list containing parameters + + Returns + ------- + Parameters + The created :class:`Parameters`. + + .. # noqa: D414 + """ + defaults: dict[str, Any] | None = next( + (item for item in parameter_list if isinstance(item, dict)), None + ) + parameters = {} + + for i, item in enumerate(item for item in parameter_list if not isinstance(item, dict)): + if not isinstance(item, list): + item = [item] + if not any(isinstance(v, str) for v in item): + item += [f"{i+1}"] + parameter = Parameter.from_list(item, default_options=defaults) + parameters[parameter.label] = parameter + return cls(parameters) + + @classmethod + def from_dict( + cls, + parameter_dict: dict[str, dict[str, Any] | list[float | list[Any]]], + ) -> Parameters: + """Create a :class:`Parameters` from a dictionary. + + Parameters + ---------- + parameter_dict: dict[str, dict[str, Any] | list[float | list[Any]]] + A parameter dictionary containing parameters. + + Returns + ------- + Parameters + The created :class:`Parameters` + + .. # noqa: D414 + """ + parameters = {} + for label, param_def, default in flatten_parameter_dict(parameter_dict): + parameter = Parameter.from_list(param_def, default_options=default) + label += f".{parameter.label}" + parameter.label = label + parameters[label] = parameter + + return cls(parameters) + + @classmethod + def from_parameter_dict_list(cls, parameter_dict_list: list[dict[str, Any]]) -> Parameters: + """Create :class:`Parameters` from a list of parameter dictionaries. + + Parameters + ---------- + parameter_dict_list : list[dict[str, Any]] + A list of parameter dictionaries. 
+    @classmethod
+    def from_parameter_dict_list(cls, parameter_dict_list: list[dict[str, Any]]) -> Parameters:
+        """Create :class:`Parameters` from a list of parameter dictionaries.
+
+        Parameters
+        ----------
+        parameter_dict_list : list[dict[str, Any]]
+            A list of parameter dictionaries.
+
+        Returns
+        -------
+        Parameters
+            The created :class:`Parameters`.
+
+        .. # noqa: D414
+        """
+        parameters = {}
+        for parameter_dict in parameter_dict_list:
+            parameter = Parameter(**parameter_dict)
+            parameters[parameter.label] = parameter
+        return cls(parameters)
+
+    @classmethod
+    def from_dataframe(cls, df: pd.DataFrame, source: str = "DataFrame") -> Parameters:
+        """Create a :class:`Parameters` from a :class:`pandas.DataFrame`.
+
+        Parameters
+        ----------
+        df : pd.DataFrame
+            The source data frame.
+        source : str
+            Optional name of the source file, used for error messages.
+
+        Raises
+        ------
+        ValueError
+            Raised if the columns 'label' or 'value' don't exist. Also raised if the columns
+            'minimum', 'maximum' or 'value' contain non-numeric values or if the columns
+            'non_negative' or 'vary' contain non-boolean values.
+
+        Returns
+        -------
+        Parameters
+            The created :class:`Parameters`.
+
+        .. # noqa: D414
+        """
+        for column_name in ["label", "value"]:
+            if column_name not in df:
+                raise ValueError(f"Missing column '{column_name}' in '{source}'")
+
+        for column_name in ["minimum", "maximum", "value"]:
+            if column_name in df and any(not np.isreal(v) for v in df[column_name]):
+                raise ValueError(f"Column '{column_name}' in '{source}' has non numeric values")
+
+        for column_name in ["non_negative", "vary"]:
+            # Guard against missing columns before normalizing integer flags to booleans.
+            if column_name not in df:
+                continue
+            df[column_name] = [v != 0 if isinstance(v, int) else v for v in df[column_name]]
+            if any(not isinstance(v, bool) for v in df[column_name]):
+                raise ValueError(f"Column '{column_name}' in '{source}' has non boolean values")
+
+        # clean NaN in expressions
+        if "expression" in df:
+            expressions = df["expression"].to_list()
+            df["expression"] = [expr if isinstance(expr, str) else None for expr in expressions]
+        return cls.from_parameter_dict_list(df.to_dict(orient="records"))
+
+    def to_dataframe(self) -> pd.DataFrame:
+        """Create a pandas data frame from the parameters.
+
+        Returns
+        -------
+        pd.DataFrame
+            The created data frame.
+        """
+        return pd.DataFrame(self.to_parameter_dict_list())
+
+    def to_parameter_dict_list(self) -> list[dict[str, Any]]:
+        """Create a list of parameter dictionaries from the parameters.
+
+        Returns
+        -------
+        list[dict[str, Any]]
+            A list of parameter dictionaries.
+        """
+        return [p.as_dict() for p in self.all()]
+
+    def to_parameter_dict_or_list(self, serialize_parameters: bool = False) -> dict | list:
+        """Convert to a dict or list of parameter definitions.
+
+        Parameters
+        ----------
+        serialize_parameters : bool
+            If true, the parameters will be serialized into list representation.
+
+        Returns
+        -------
+        dict | list
+            A dict or list of parameter definitions.
+        """
+        if all("." not in p.label for p in self.all()):
+            return list(self.all())
+        parameter_dict: dict[str, Any] = {}
+        for parameter in self.all():
+            path = parameter.label.split(".")
+            nodes = path[:-2]
+            node = parameter_dict
+            for n in nodes:
+                if n not in node:
+                    node[n] = {}
+                node = node[n]
+            upper_node = path[-2]
+            if upper_node not in node:
+                node[upper_node] = []
+            node[upper_node].append(
+                parameter.as_list(label_short=True) if serialize_parameters else parameter
+            )
+        return parameter_dict
+
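+    # Illustrative round trip (sketch): a Parameters instance survives
+    # conversion to a data frame and back unchanged.
+    #
+    #     df = parameters.to_dataframe()
+    #     assert Parameters.from_dataframe(df) == parameters
+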
+ """ + self.set_from_label_and_value_arrays( + # Omit 0th element with `iteration` label + history.parameter_labels[1:], + history.get_parameters(index)[1:], + ) + + def copy(self) -> Parameters: + """Create a copy of the :class:`Parameters`. + + Returns + ------- + Parameters : + A copy of the :class:`Parameters`. + + .. # noqa: D414 + """ + return Parameters( + {label: parameter.copy() for label, parameter in self._parameters.items()} + ) + + def all(self) -> Generator[Parameter, None, None]: + """Iterate over all parameters. + + Yields + ------ + Parameter + A parameter in the parameters. + """ + yield from self._parameters.values() + + def has(self, label: str) -> bool: + """Check if a parameter with the given label is in the group or in a subgroup. + + Parameters + ---------- + label : str + The label of the parameter, with its path in a :class:`ParameterGroup` prepended. + + Returns + ------- + bool + Whether a parameter with the given label exists in the group. + """ + return label in self._parameters + + def get(self, label: str) -> Parameter: + """Get a :class:`Parameter` by its label. + + Parameters + ---------- + label : str + The label of the parameter, with its path in a :class:`ParameterGroup` prepended. + + Returns + ------- + Parameter + The parameter. + + Raises + ------ + ParameterNotFoundException + Raised if no parameter with the given label exists. + """ + try: + return self._parameters[label] + except KeyError as error: + raise ParameterNotFoundException(label) from error + + def update_parameter_expression(self): + """Update all parameters which have an expression. + + Raises + ------ + ValueError + Raised if an expression evaluates to a non-numeric value. + """ + for parameter in self.all(): + if parameter.expression is not None: + value = self._evaluator(parameter.transformed_expression) + if not isinstance(value, (int, float)): + raise ValueError( + f"Expression '{parameter.expression}' of parameter '{parameter.label}' " + f"evaluates to non numeric value '{value}'." + ) + parameter.value = value + + def get_label_value_and_bounds_arrays( + self, exclude_non_vary: bool = False + ) -> tuple[list[str], np.ndarray, np.ndarray, np.ndarray]: + """Return a arrays of all parameter labels, values and bounds. + + Parameters + ---------- + exclude_non_vary: bool + If true, parameters with `vary=False` are excluded. + + Returns + ------- + tuple[list[str], np.ndarray, np.ndarray, np.ndarray] + A tuple containing a list of parameter labels and + an array of the values, lower and upper bounds. + """ + self.update_parameter_expression() + + labels = [] + values = [] + lower_bounds = [] + upper_bounds = [] + + for parameter in self.all(): + if not exclude_non_vary or parameter.vary: + labels.append(parameter.label) + value, minimum, maximum = parameter.get_value_and_bounds_for_optimization() + values.append(value) + lower_bounds.append(minimum) + upper_bounds.append(maximum) + + return labels, np.asarray(values), np.asarray(lower_bounds), np.asarray(upper_bounds) + + def set_from_label_and_value_arrays(self, labels: list[str], values: np.ndarray): + """Update the parameter values from a list of labels and values. + + Parameters + ---------- + labels : list[str] + A list of parameter labels. + values : np.ndarray + An array of parameter values. + + Raises + ------ + ValueError + Raised if the size of the labels does not match the stize of values. 
+ """ + if len(labels) != len(values): + raise ValueError( + f"Length of labels({len(labels)}) not equal to length of values({len(values)})." + ) + + for label, value in zip(labels, values): + self.get(label).set_value_from_optimization(value) + + self.update_parameter_expression() + + def markdown(self, float_format: str = ".3e") -> MarkdownStr: + """Format the :class:`ParameterGroup` as markdown string. + + This is done by recursing the nested :class:`ParameterGroup` tree. + + Parameters + ---------- + float_format: str + Format string for floating point numbers, by default ".3e" + + Returns + ------- + MarkdownStr : + The markdown representation as string. + """ + return param_dict_to_markdown(self.to_parameter_dict_or_list(), float_format=float_format) + + def _repr_markdown_(self) -> str: + """Create a markdown representation. + + Special method used by ``ipython`` to render markdown. + + Returns + ------- + str : + The markdown representation as string. + """ + return str(self.markdown()) + + @property + def labels(self) -> list[str]: + """List of all labels. + + Returns + ------- + list[str] + """ + return sorted(p.label for p in self.all()) + + def __str__(self) -> str: + """Representation used by print and str.""" + return str(self.markdown()) + + def __repr__(self) -> str: + """Representation debug.""" + params = [f"{p.label!r}: {repr(p)}" for p in self.all()] + return f"Parameters({{{', '.join(params)}}})" + + def __eq__(self, other: object) -> bool: + """==""" # noqa: D400 + if isinstance(other, Parameters): + return self.labels == other.labels and all( + self.get(label)._deep_equals(other.get(label)) for label in self.labels + ) + raise NotImplementedError( + "Parameters can only be compared with instances of Parameters, " + f"not with {type(other).__qualname__!r}." + ) + + +def flatten_parameter_dict( + parameter_dict: dict, +) -> Generator[tuple[str, list[Any], dict | None], None, None]: + """Flatten a parameter dictionary. + + Parameters + ---------- + parameter_dict: dict + The parameter dictionary. + + Yields + ------ + tuple[str, list[Any], dict | None + The concatenated keys, the parameter definition and default options. + """ + for k, v in parameter_dict.items(): + if isinstance(v, dict): + for sub_k, v, d in flatten_parameter_dict(v): + yield f"{k}.{sub_k}", v, d + elif isinstance(v, list): + defaults: dict[str, Any] | None = next( + (item for item in v if isinstance(item, dict)), None + ) + for i, v in enumerate(v for v in v if not isinstance(v, dict)): + if not isinstance(v, list): + v = [str(i + 1), v] + elif not any(isinstance(v, str) for v in v): + v += [str(i + 1)] + yield k, v, defaults + + +def param_dict_to_markdown( + parameters: dict | list, + float_format: str = ".3e", + depth: int = 0, + label: str | None = None, +) -> MarkdownStr: + """Format the :class:`Parameters` as markdown string. + + This is done by recursing the nested :class:`Parameters` tree. + + Parameters + ---------- + parameters: dict | list + The parameter dict or list. + float_format: str + Format string for floating point numbers, by default ".3e" + depth: int + The depth of the parameter dict. + label: str | None + The label of the parameter dict. + + Returns + ------- + MarkdownStr : + The markdown representation as string. 
+ """ + node_indentation = " " * depth + return_string = "" + table_header = [ + "_Label_", + "_Value_", + "_Standard Error_", + "_t-value_", + "_Minimum_", + "_Maximum_", + "_Vary_", + "_Non-Negative_", + "_Expression_", + ] + if label is not None: + return_string += f"{node_indentation}* __{label}__:\n" + if isinstance(parameters, list): + parameter_rows = [ + [ + parameter.label_short, + parameter.value, + parameter.standard_error, + repr(pretty_format_numerical(parameter.value / parameter.standard_error)), + parameter.minimum, + parameter.maximum, + parameter.vary, + parameter.non_negative, + f"`{parameter.expression}`", + ] + for parameter in parameters + ] + parameter_table = indent( + tabulate( + parameter_rows, + headers=table_header, + tablefmt="github", + missingval="None", + floatfmt=float_format, + ), + f" {node_indentation}", + ) + return_string += f"\n{parameter_table}\n\n" + else: + for label, child in sorted(parameters.items()): + return_string += str( + param_dict_to_markdown( + child, float_format=float_format, depth=depth + 1, label=label + ) + ) + return MarkdownStr(return_string.replace("'", " ")) diff --git a/glotaran/parameter/test/test_parameter.py b/glotaran/parameter/test/test_parameter.py index c764d4dc8..019c285af 100644 --- a/glotaran/parameter/test/test_parameter.py +++ b/glotaran/parameter/test/test_parameter.py @@ -1,169 +1,151 @@ from __future__ import annotations import pickle +import re +from typing import Any import numpy as np import pytest -from glotaran.io import load_parameters from glotaran.parameter import Parameter -@pytest.mark.parametrize("label, expected", (("foo", "foo"), (0, "0"), (1, "1"), (None, None))) -def test_parameter_label_always_str_or_None(label: str | int | None, expected: str | None): - """Parameter.label is always a string or None""" +@pytest.mark.parametrize( + "key_name, value_1, value_2", + ( + ("value", 1, 2), + ("vary", True, False), + ("minimum", -np.inf, -1), + ("maximum", np.inf, 1), + ("expression", None, "$a.1*10"), + ("standard_error", np.nan, 1), + ("non_negative", True, False), + ), +) +def test_parameter__deep_equals(key_name: str, value_1: Any, value_2: Any): + parameter_1 = Parameter(label="foo", **{key_name: value_1}) + parameter_2 = Parameter(label="foo", **{key_name: value_1}) + assert parameter_1._deep_equals(parameter_2) + + parameter_3 = Parameter(label="foo", **{key_name: value_1}) + parameter_4 = Parameter(label="foo", **{key_name: value_2}) + assert not parameter_3._deep_equals(parameter_4) + + +@pytest.mark.parametrize("label, expected", (("foo", "foo"), (0, "0"), (1, "1"))) +def test_parameter_label_always_str_or_None(label: str | int, expected: str): + """Parameter.label is always a string""" parameter = Parameter(label=label) # type:ignore[arg-type] assert parameter.label == expected @pytest.mark.parametrize( "label", - (2.0, np.nan, "foo.bar"), + ("exp", np.nan), ) def test_parameter_label_error_wrong_label_pattern(label: str | int | float): """Error if label can't be casted to a valid label str""" - with pytest.raises(ValueError, match=f"'{label}' is not a valid group label."): + with pytest.raises(ValueError, match=f"'{label}' is not a valid parameter label."): Parameter(label=label) # type:ignore[arg-type] -def test_parameter_repr(): +@pytest.mark.parametrize( + "parameter, expected_repr", + ( + ( + Parameter(label="foo"), + "Parameter(label='foo')", + ), + ( + Parameter(label="foo", expression="$foo.bar", value=1.0, vary=True), + # vary gets set to False due to the usage of expression + 
"Parameter(label='foo', value=1.0, expression='$foo.bar', vary=False)", + ), + ), +) +def test_parameter_repr(parameter: Parameter, expected_repr: str): """Repr creates code to recreate the object.""" - result = Parameter(label="foo", value=1.0, expression="$foo.bar", vary=False) - result_short = Parameter(label="foo", value=1, expression="$foo.bar") - expected = "Parameter(label='foo', value=1.0, expression='$foo.bar', vary=False)" + print(parameter.__repr__()) + assert parameter.__repr__() == expected_repr + assert parameter._deep_equals(eval(expected_repr)) - assert result == result_short - assert result.vary == result_short.vary - assert result.__repr__() == expected - assert result_short.__repr__() == expected +def test_parameter_from_list(): + params = [["5", 1], ["4", 2], ["3", 3]] -def test_parameter_scientific_values(): - values = [5e3, -4.2e-4, 3e-2, -2e6] - params = """ - - ["1", 5e3] - - ["2", -4.2e-4] - - ["3", 3e-2] - - ["4", -2e6] - """ - - params = load_parameters(params, format_name="yml_str") + parameters = [Parameter.from_list(v) for v in params] - assert [p.value for _, p in params.all()] == values + assert [p.label for p in parameters] == [v[0] for v in params] + assert [p.value for p in parameters] == [v[1] for v in params] -def test_parameter_from_list(): - params = """ - - ["5", 1] - - ["4", 2] - - ["3", 3] - """ - - params = load_parameters(params, format_name="yml_str") +def test_parameter_options(): + params = [ + ["5", 1, {"non-negative": False, "min": -1, "max": 1, "vary": False}], + ["6", 4e2, {"non-negative": True, "min": -7e2, "max": 8e2, "vary": True}], + ["7", 2e4], + ] - assert len(list(params.all())) == 3 - assert [p.label for _, p in params.all()] == [f"{i}" for i in range(5, 2, -1)] - assert [p.value for _, p in params.all()] == list(range(1, 4)) + parameters = [Parameter.from_list(v) for v in params] + assert parameters[0].value == 1.0 + assert not parameters[0].non_negative + assert parameters[0].minimum == -1 + assert parameters[0].maximum == 1 + assert not parameters[0].vary -def test_parameter_options(): - params = """ - - ["5", 1, {non-negative: false, min: -1, max: 1, vary: false}] - - ["6", 4e2, {non-negative: true, min: -7e2, max: 8e2, vary: true}] - - ["7", 2e4] - """ - - params = load_parameters(params, format_name="yml_str") - - assert params.get("5").value == 1.0 - assert not params.get("5").non_negative - assert params.get("5").minimum == -1 - assert params.get("5").maximum == 1 - assert not params.get("5").vary - - assert params.get("6").value == 4e2 - assert params.get("6").non_negative - assert params.get("6").minimum == -7e2 - assert params.get("6").maximum == 8e2 - assert params.get("6").vary - - assert params.get("7").value == 2e4 - assert not params.get("7").non_negative - assert params.get("7").minimum == float("-inf") - assert params.get("7").maximum == float("inf") - assert params.get("7").vary - - -def test_parameter_block_options(): - params = """ - block: - - 1.0 - - [3.4, {vary: true}] - - {vary: false} - """ - - params = load_parameters(params, format_name="yml_str") - assert not params.get("block.1").vary - assert params.get("block.2").vary - - -def test_parameter_set_from_group(): - """Parameter extracted from group has correct values""" - group = load_parameters( - "foo:\n - [\"1\", 123,{non-negative: true, min: 10, max: 8e2, vary: true, expr:'2'}]", - format_name="yml_str", - ) - parameter = Parameter(full_label="foo.1") - parameter.set_from_group(group=group) + assert parameters[1].value == 4e2 + assert 
parameters[1].non_negative + assert parameters[1].minimum == -7e2 + assert parameters[1].maximum == 8e2 + assert parameters[1].vary - assert parameter.value == 123 - assert parameter.non_negative is True - assert np.allclose(parameter.minimum, 10) - assert np.allclose(parameter.maximum, 800) - assert parameter.vary is True - # Set to None since value and expr were provided? - assert parameter.expression is None + assert parameters[2].value == 2e4 + assert not parameters[2].non_negative + assert parameters[2].minimum == float("-inf") + assert parameters[2].maximum == float("inf") + assert parameters[2].vary def test_parameter_value_not_numeric_error(): """Error if value isn't numeric.""" - with pytest.raises(TypeError, match="Parameter value must be numeric"): - Parameter(value="foo") # type:ignore[arg-type] + with pytest.raises(TypeError): + Parameter(label="", value="foo") # type:ignore[arg-type] def test_parameter_maximum_not_numeric_error(): """Error if maximum isn't numeric.""" - with pytest.raises(TypeError, match="Parameter maximum must be numeric"): - Parameter(maximum="foo") # type:ignore[arg-type] + with pytest.raises(TypeError): + Parameter(label="", maximum="foo") # type:ignore[arg-type] def test_parameter_minimum_not_numeric_error(): """Error if minimum isn't numeric.""" - with pytest.raises(TypeError, match="Parameter minimum must be numeric"): - Parameter(minimum="foo") # type:ignore[arg-type] + with pytest.raises(TypeError): + Parameter(label="", minimum="foo") # type:ignore[arg-type] def test_parameter_non_negative(): - notnonneg = Parameter(value=1, non_negative=False) + notnonneg = Parameter(label="", value=1, non_negative=False) valuenotnoneg, _, _ = notnonneg.get_value_and_bounds_for_optimization() assert np.allclose(1, valuenotnoneg) notnonneg.set_value_from_optimization(valuenotnoneg) assert np.allclose(1, notnonneg.value) - nonneg1 = Parameter(value=1, non_negative=True) + nonneg1 = Parameter(label="", value=1, non_negative=True) value1, _, _ = nonneg1.get_value_and_bounds_for_optimization() assert not np.allclose(1, value1) nonneg1.set_value_from_optimization(value1) assert np.allclose(1, nonneg1.value) - nonneg2 = Parameter(value=2, non_negative=True) + nonneg2 = Parameter(label="", value=2, non_negative=True) value2, _, _ = nonneg2.get_value_and_bounds_for_optimization() assert not np.allclose(2, value2) nonneg2.set_value_from_optimization(value2) assert np.allclose(2, nonneg2.value) - nonnegminmax = Parameter(value=5, minimum=3, maximum=6, non_negative=True) + nonnegminmax = Parameter(label="", value=5, minimum=3, maximum=6, non_negative=True) value5, minimum, maximum = nonnegminmax.get_value_and_bounds_for_optimization() assert not np.allclose(5, value5) assert not np.allclose(3, minimum) @@ -173,37 +155,41 @@ def test_parameter_non_negative(): @pytest.mark.parametrize( "case", [ - ("$1", "group.get('1').value"), + ("$1", "parameters.get('1').value"), ( "1 - $kinetic.1 * exp($kinetic.2) + $kinetic.3", - "1 - group.get('kinetic.1').value * exp(group.get('kinetic.2').value) " - "+ group.get('kinetic.3').value", + "1 - parameters.get('kinetic.1').value * exp(parameters.get('kinetic.2').value) " + "+ parameters.get('kinetic.3').value", ), ("2", "2"), ( "1 - sum([$kinetic.1, $kinetic.2])", - "1 - sum([group.get('kinetic.1').value, group.get('kinetic.2').value])", + "1 - sum([parameters.get('kinetic.1').value, parameters.get('kinetic.2').value])", + ), + ("exp($kinetic.4)", "exp(parameters.get('kinetic.4').value)"), + ("$kinetic.5", 
"parameters.get('kinetic.5').value"), + ( + "$parameters.parameters.param1 + $kinetic6", + "parameters.get('parameters.parameters.param1').value " + "+ parameters.get('kinetic6').value", ), - ("exp($kinetic.4)", "exp(group.get('kinetic.4').value)"), - ("$kinetic.5", "group.get('kinetic.5').value"), ( - "$group.sub_group.param1 + $kinetic6", - "group.get('group.sub_group.param1').value + group.get('kinetic6').value", + "$foo.7.bar + $kinetic6", + "parameters.get('foo.7.bar').value " "+ parameters.get('kinetic6').value", ), - ("$foo.7.bar + $kinetic6", "group.get('foo.7.bar').value + group.get('kinetic6').value"), - ("$1", "group.get('1').value"), - ("$1-$2", "group.get('1').value-group.get('2').value"), - ("$1-$5", "group.get('1').value-group.get('5').value"), + ("$1", "parameters.get('1').value"), + ("$1-$2", "parameters.get('1').value-parameters.get('2').value"), + ("$1-$5", "parameters.get('1').value-parameters.get('5').value"), ( "100 - $inputs1.s1 - $inputs1.s3 - $inputs1.s8 - $inputs1.s12", - "100 - group.get('inputs1.s1').value - group.get('inputs1.s3').value " - "- group.get('inputs1.s8').value - group.get('inputs1.s12').value", + "100 - parameters.get('inputs1.s1').value - parameters.get('inputs1.s3').value " + "- parameters.get('inputs1.s8').value - parameters.get('inputs1.s12').value", ), ], ) def test_transform_expression(case): expression, wanted_parameters = case - parameter = Parameter(expression=expression) + parameter = Parameter(label="", expression=expression) assert parameter.transformed_expression == wanted_parameters # just for the test coverage so the if condition is wrong assert parameter.transformed_expression == wanted_parameters @@ -216,96 +202,39 @@ def test_label_validator(): "_valid2", "extra_valid3", ] - - assert all(list(map(Parameter.valid_label, valid_names))) + for label in valid_names: + Parameter(label) invalid_names = [ "testΓ©", - "kinetic.1", - "kinetic_red.3", - "foo.7.bar", - "_ilikeunderscoresatbegeninngin.justbecause", - "42istheanswer.42", "kinetic::red", "kinetic_blue+kinetic_red", "makesthissense=trueandfalse", "what/about\\slashes", "$invalid", "round", - "group", + "parameters", ] - assert not any(list(map(Parameter.valid_label, invalid_names))) - - -def test_parameter_expressions(): - params = """ - - ["1", 2] - - ["2", 5] - - ["3", {expr: '$1 * exp($2)'}] - - ["4", {expr: '2'}] - """ - - params = load_parameters(params, format_name="yml_str") - - assert params.get("3").expression is not None - assert not params.get("3").vary - assert params.get("3").value == 2 * np.exp(5) - assert params.get("3").value == params.get("1") * np.exp(params.get("2")) - assert params.get("4").value == 2 - - with pytest.raises(ValueError): - params_bad_expr = """ - - ["3", {expr: 'None'}] - """ - load_parameters(params_bad_expr, format_name="yml_str") - - -def test_parameter_expressions_groups(): - params_vary_explicit = """ - b: - - [0.25, {vary: True}] - - [0.75, {expr: '1 - $b.1', vary: False}] - rates: - - ["total", 2, {vary: True}] - - ["branch1", {expr: '$rates.total * $b.1', vary: False}] - - ["branch2", {expr: '$rates.total * $b.2', vary: False}] - """ - params_vary_implicit = """ - b: - - [0.25] - - [0.75, {expr: '1 - $b.1'}] - rates: - - ["total", 2] - - ["branch1", {expr: '$rates.total * $b.1'}] - - ["branch2", {expr: '$rates.total * $b.2'}] - """ - params_label_explicit = """ - b: - - ["1", 0.25] - - ["2", 0.75, {expr: '1 - $b.1'}] - rates: - - ["total", 2] - - ["branch1", {expr: '$rates.total * $b.1'}] - - ["branch2", {expr: '$rates.total * $b.2'}] - 
""" - - for params in [params_vary_explicit, params_vary_implicit, params_label_explicit]: - params = load_parameters(params, format_name="yml_str") - - assert params.get("b.1").expression is None - assert params.get("b.1").vary - assert not params.get("b.2").vary - assert params.get("b.2").expression is not None - assert params.get("rates.branch1").value == params.get("rates.total") * params.get("b.1") - assert params.get("rates.branch2").value == params.get("rates.total") * params.get("b.2") - assert params.get("rates.total").vary - assert not params.get("rates.branch1").vary - assert not params.get("rates.branch2").vary + + for label in invalid_names: + print(label) + with pytest.raises( + ValueError, match=re.escape(f"'{label}' is not a valid parameter label.") + ): + Parameter(label=label) def test_parameter_pickle(tmpdir): - parameter = Parameter("testlabel", "testlabelfull", "testexpression", 1, 2, True, 42, False) + parameter = Parameter( + label="testlabel", + expression="testexpression", + minimum=1, + maximum=2, + non_negative=True, + value=42, + vary=False, + ) with open(tmpdir.join("test_param_pickle"), "wb") as f: pickle.dump(parameter, f) @@ -317,11 +246,11 @@ def test_parameter_pickle(tmpdir): def test_parameter_numpy_operations(): """Operators work like a float""" - parm1 = Parameter(value=1) - parm1_neg = Parameter(value=-1) - parm2 = Parameter(value=2) - parm3 = Parameter(value=3) - parm3_5 = Parameter(value=3.5) + parm1 = Parameter(label="foo", value=1) + parm1_neg = Parameter(label="foo", value=-1) + parm2 = Parameter(label="foo", value=2) + parm3 = Parameter(label="foo", value=3) + parm3_5 = Parameter(label="foo", value=3.5) assert parm1 == 1 assert parm1 != 2 @@ -349,7 +278,6 @@ def test_parameter_numpy_operations(): def test_parameter_dict_roundtrip(): param = Parameter( label="foo", - full_label="bar.foo", expression="1", maximum=2, minimum=1, @@ -359,13 +287,37 @@ def test_parameter_dict_roundtrip(): ) param_dict = param.as_dict() - param_from_dict = Parameter.from_dict(param_dict) + print(param_dict) + param_from_dict = Parameter(**param_dict) assert param.label == param_from_dict.label - assert param.full_label == param_from_dict.full_label assert param.expression == param_from_dict.expression assert param.maximum == param_from_dict.maximum assert param.minimum == param_from_dict.minimum assert param.non_negative == param_from_dict.non_negative assert param.value == param_from_dict.value assert param.vary == param_from_dict.vary + + +def test_parameter_list_roundtrip(): + param = Parameter( + label="foo", + expression="1", + maximum=2, + minimum=1, + non_negative=True, + value=42, + vary=False, + ) + + param_list = param.as_list() + print(param_list) + param_from_list = Parameter.from_list(param_list) + + assert param.label == param_from_list.label + assert param.expression == param_from_list.expression + assert param.maximum == param_from_list.maximum + assert param.minimum == param_from_list.minimum + assert param.non_negative == param_from_list.non_negative + assert param.value == param_from_list.value + assert param.vary == param_from_list.vary diff --git a/glotaran/parameter/test/test_parameter_group.py b/glotaran/parameter/test/test_parameter_group.py deleted file mode 100644 index 2fea4cc34..000000000 --- a/glotaran/parameter/test/test_parameter_group.py +++ /dev/null @@ -1,368 +0,0 @@ -from __future__ import annotations - -from textwrap import dedent - -import numpy as np - -from glotaran.io import load_parameters -from glotaran.io import save_parameters 
-from glotaran.parameter import ParameterGroup - - -def test_parameter_group_copy(): - params = """ - kinetic: - - ["5", 1, {non-negative: true, min: -1, max: 1, vary: false}] - - 4 - - 5 - j: - - 7 - - 8 - """ - params = load_parameters(params, format_name="yml_str") - copy = params.copy() - - for label, parameter in params.all(): - assert copy.has(label) - copied_parameter = copy.get(label) - assert parameter.value == copied_parameter.value - assert parameter.non_negative == copied_parameter.non_negative - assert parameter.minimum == copied_parameter.minimum - assert parameter.maximum == copied_parameter.maximum - assert parameter.vary == copied_parameter.vary - - -def test_parameter_group_from_list(): - params = """ - - 5 - - 4 - - 3 - - 2 - - 1 - """ - - params = load_parameters(params, format_name="yml_str") - - assert len(list(params.all())) == 5 - - assert [p.label for _, p in params.all()] == [f"{i}" for i in range(1, 6)] - assert [p.value for _, p in params.all()] == list(range(1, 6))[::-1] - - -def test_parameter_group_from_dict(): - params = """ - kinetic: - - 3 - - 4 - - 5 - j: - - 7 - - 8 - """ - - params = load_parameters(params, format_name="yml_str") - - assert len(list(params.all())) == 5 - group = params["kinetic"] - assert len(list(group.all())) == 3 - assert [p.label for _, p in group.all()] == [f"{i}" for i in range(1, 4)] - assert [p.value for _, p in group.all()] == list(range(3, 6)) - group = params["j"] - assert len(list(group.all())) == 2 - assert [p.label for _, p in group.all()] == [f"{i}" for i in range(1, 3)] - assert [p.value for _, p in group.all()] == list(range(7, 9)) - - -def test_parameter_group_from_dict_nested(): - params = """ - kinetic: - j: - - 7 - - 8 - - 9 - """ - - params = load_parameters(params, format_name="yml_str") - assert len(list(params.all())) == 3 - group = params["kinetic"] - assert len(list(group.all())) == 3 - group = group["j"] - assert len(list(group.all())) == 3 - assert [p.label for _, p in group.all()] == [f"{i}" for i in range(1, 4)] - assert [p.value for _, p in group.all()] == list(range(7, 10)) - - assert params.get("kinetic.j.1").full_label == "kinetic.j.1" - - roundtrip_df = ParameterGroup.from_dataframe(params.to_dataframe()).to_dataframe() - assert all(roundtrip_df.label == params.to_dataframe().label) - - -def test_parameter_group_to_array(): - params = """ - - ["1", 1, {non-negative: false, min: -1, max: 1, vary: false}] - - ["2", 4e2, {non-negative: true, min: 10, max: 8e2, vary: true}] - - ["3", 2e4] - """ - - params = load_parameters(params, format_name="yml_str") - - labels, values, lower_bounds, upper_bounds = params.get_label_value_and_bounds_arrays( - exclude_non_vary=False - ) - - assert len(labels) == 3 - assert len(values) == 3 - assert len(lower_bounds) == 3 - assert len(upper_bounds) == 3 - - assert labels == ["1", "2", "3"] - assert np.allclose(values, [1, np.log(4e2), 2e4]) - assert np.allclose(lower_bounds, [-1, np.log(10), -np.inf]) - assert np.allclose(upper_bounds, [1, np.log(8e2), np.inf]) - - ( - labels_only_vary, - values_only_vary, - lower_bounds_only_vary, - upper_bounds_only_vary, - ) = params.get_label_value_and_bounds_arrays(exclude_non_vary=True) - - assert len(labels_only_vary) == 2 - assert len(values_only_vary) == 2 - assert len(lower_bounds_only_vary) == 2 - assert len(upper_bounds_only_vary) == 2 - - assert labels_only_vary == ["2", "3"] - - -def test_parameter_group_set_from_label_and_value_arrays(): - params = """ - - ["1", 1, {non-negative: false, min: -1, max: 1, vary: false}] - - 
["2", 4e2, {non-negative: true, min: 10, max: 8e2, vary: true}] - - ["3", 2e4] - """ - - params = load_parameters(params, format_name="yml_str") - - labels = ["1", "2", "3"] - values = [0, np.log(6e2), 42] - - params.set_from_label_and_value_arrays(labels, values) - - values[1] = np.exp(values[1]) - - for i in range(3): - assert params.get(f"{i+1}").value == values[i] - - -def test_parameter_group_from_csv(tmpdir): - - TEST_CSV = dedent( - """\ - label, value, minimum, maximum, vary, non-negative, expression - rates.k1,0.050,0,5,True,True,None - rates.k2,None,,,True,True,$rates.k1 * 2 - rates.k3,2.311,,,True,True,None - pen.eq.1,1.000,,,False,False,None - """ - ) - - csv_path = tmpdir.join("parameters.csv") - with open(csv_path, "w") as f: - f.write(TEST_CSV) - - params = load_parameters(csv_path) - - assert "rates" in params - - assert params.has("rates.k1") - p = params.get("rates.k1") - assert p.label == "k1" - assert p.value == 0.05 - assert p.minimum == 0 - assert p.maximum == 5 - assert p.vary - assert p.non_negative - assert p.expression is None - - assert params.has("rates.k2") - p = params.get("rates.k2") - assert p.label == "k2" - assert p.value == params.get("rates.k1") * 2 - assert p.minimum == -np.inf - assert p.maximum == np.inf - assert not p.vary - assert not p.non_negative - assert p.expression == "$rates.k1 * 2" - - assert params.has("rates.k3") - p = params.get("rates.k3") - assert p.label == "k3" - assert p.value == 2.311 - assert p.minimum == -np.inf - assert p.maximum == np.inf - assert p.vary - assert p.non_negative - assert p.expression is None - - assert "pen" in params - assert "eq" in params["pen"] - - assert params.has("pen.eq.1") - p = params.get("pen.eq.1") - assert p.label == "1" - assert p.value == 1.0 - assert p.minimum == -np.inf - assert p.maximum == np.inf - assert not p.vary - assert not p.non_negative - assert p.expression is None - - -def test_parameter_group_to_csv(tmpdir): - csv_path = tmpdir.join("parameters.csv") - params = load_parameters( - """ - b: - - ["1", 0.25, {vary: false, min: 0, max: 8}] - - ["2", 0.75, {expr: '1 - $b.1', non-negative: true}] - rates: - - ["total", 2] - - ["branch1", {expr: '$rates.total * $b.1'}] - - ["branch2", {expr: '$rates.total * $b.2'}] - """, - format_name="yml_str", - ) - for _, p in params.all(): - p.standard_error = 42 - - save_parameters(params, csv_path, "csv") - wanted = dedent( - """\ - label,value,expression,minimum,maximum,non-negative,vary,standard-error - b.1,0.25,None,0.0,8.0,False,False,42 - b.2,0.75,1 - $b.1,,,False,False,42 - rates.total,2.0,None,,,False,True,42 - rates.branch1,0.5,$rates.total * $b.1,,,False,False,42 - rates.branch2,1.5,$rates.total * $b.2,,,False,False,42 - """ - ) - - with open(csv_path) as f: - got = f.read() - print(got) - assert got == wanted - params_from_csv = load_parameters(csv_path) - - for label, p in params.all(): - assert params_from_csv.has(label) - p_from_csv = params_from_csv.get(label) - assert p.label == p_from_csv.label - assert p.value == p_from_csv.value - assert p.minimum == p_from_csv.minimum - assert p.maximum == p_from_csv.maximum - assert p.vary == p_from_csv.vary - assert p.non_negative == p_from_csv.non_negative - assert p.expression == p_from_csv.expression - - -def test_parameter_group_to_from_parameter_dict_list(): - parameter_group = load_parameters( - """ - b: - - ["1", 0.25, {vary: false, min: 0, max: 8}] - - ["2", 0.75, {expr: '1 - $b.1', non-negative: true}] - rates: - - ["total", 2] - - ["branch1", {expr: '$rates.total * $b.1'}] - - ["branch2", 
{expr: '$rates.total * $b.2'}] - """, - format_name="yml_str", - ) - - parameter_dict_list = parameter_group.to_parameter_dict_list() - parameter_group_from_dict_list = ParameterGroup.from_parameter_dict_list(parameter_dict_list) - - for label, wanted in parameter_group.all(): - got = parameter_group_from_dict_list.get(label) - - assert got.label == wanted.label - assert got.full_label == wanted.full_label - assert got.expression == wanted.expression - assert got.maximum == wanted.maximum - assert got.minimum == wanted.minimum - assert got.non_negative == wanted.non_negative - assert got.value == wanted.value - assert got.vary == wanted.vary - - -def test_parameter_group_to_from_df(): - parameter_group = load_parameters( - """ - b: - - ["1", 0.25, {vary: false, min: 0, max: 8}] - - ["2", 0.75, {expr: '1 - $b.1', non-negative: true}] - rates: - - ["total", 2] - - ["branch1", {expr: '$rates.total * $b.1'}] - - ["branch2", {expr: '$rates.total * $b.2'}] - """, - format_name="yml_str", - ) - - for _, p in parameter_group.all(): - p.standard_error = 42 - - parameter_df = parameter_group.to_dataframe() - - for column in [ - "label", - "value", - "standard-error", - "expression", - "minimum", - "maximum", - "non-negative", - "vary", - ]: - assert column in parameter_df - - assert all(parameter_df["standard-error"] == 42) - - parameter_group_from_df = ParameterGroup.from_dataframe(parameter_df) - - for label, wanted in parameter_group.all(): - got = parameter_group_from_df.get(label) - - assert got.label == wanted.label - assert got.full_label == wanted.full_label - assert got.expression == wanted.expression - assert got.maximum == wanted.maximum - assert got.minimum == wanted.minimum - assert got.non_negative == wanted.non_negative - assert got.value == wanted.value - assert got.vary == wanted.vary - - -def test_missing_parameter_value_labels(): - """Full labels of all parameters with missing values (NaN) get listed.""" - parameter_group = load_parameters( - dedent( - """\ - b: - - ["missing_value_1",] - - ["missing_value_2"] - - ["2", 0.75] - kinetic: - j: - - ["missing_value_3"] - """ - ), - format_name="yml_str", - ) - - assert parameter_group.missing_parameter_value_labels == [ - "b.missing_value_1", - "b.missing_value_2", - "kinetic.j.missing_value_3", - ] diff --git a/glotaran/parameter/test/test_parameter_history.py b/glotaran/parameter/test/test_parameter_history.py index 08fb54a2e..a236ac263 100644 --- a/glotaran/parameter/test/test_parameter_history.py +++ b/glotaran/parameter/test/test_parameter_history.py @@ -1,13 +1,13 @@ import numpy as np -from glotaran.parameter.parameter_group import ParameterGroup from glotaran.parameter.parameter_history import ParameterHistory +from glotaran.parameter.parameters import Parameters def test_parameter_history(): - group0 = ParameterGroup.from_list([["1", 1], ["2", 4]]) - group1 = ParameterGroup.from_list([["1", 2], ["2", 5]]) - group2 = ParameterGroup.from_list([["1", 3], ["2", 6]]) + group0 = Parameters.from_list([["1", 1], ["2", 4]]) + group1 = Parameters.from_list([["1", 2], ["2", 5]]) + group2 = Parameters.from_list([["1", 3], ["2", 6]]) history = ParameterHistory() diff --git a/glotaran/parameter/test/test_parameters.py b/glotaran/parameter/test/test_parameters.py new file mode 100644 index 000000000..6c317243d --- /dev/null +++ b/glotaran/parameter/test/test_parameters.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest + +from glotaran.parameter import Parameter +from 
glotaran.parameter import Parameters
+
+
+def test_parameters_from_list():
+    params = [5, 4, 3, 2, 1]
+
+    parameters = Parameters.from_list(params)
+
+    print(parameters._parameters)
+    assert len(list(parameters.all())) == 5
+
+    assert [p.label for p in parameters.all()] == [f"{i}" for i in range(1, 6)]
+    assert [p.value for p in parameters.all()] == list(range(1, 6))[::-1]
+
+
+def test_parameters_from_dict():
+    params = {"a": [3, 4, 5], "b": [7, 8]}
+
+    parameters = Parameters.from_dict(params)
+
+    assert len(list(parameters.all())) == 5
+
+    for label, value in [
+        ("a.1", 3),
+        ("a.2", 4),
+        ("a.3", 5),
+        ("b.1", 7),
+        ("b.2", 8),
+    ]:
+        assert parameters.has(label)
+        assert parameters.get(label).label == label
+        assert parameters.get(label).value == value
+
+
+def test_parameters_from_dict_nested():
+    params = {"a": {"b": [7, 8, 9]}}
+
+    parameters = Parameters.from_dict(params)
+    assert len(list(parameters.all())) == 3
+
+    for label, value in [
+        ("a.b.1", 7),
+        ("a.b.2", 8),
+        ("a.b.3", 9),
+    ]:
+        assert parameters.has(label)
+        assert parameters.get(label).label == label
+        assert parameters.get(label).value == value
+
+
+def test_parameters_default_options():
+    params = {"block": [1.0, [3.4, {"vary": True}], {"vary": False}]}
+
+    parameters = Parameters.from_dict(params)
+    assert len(list(parameters.all())) == 2
+
+    assert not parameters.get("block.1").vary
+    assert parameters.get("block.2").vary
+
+
+def test_parameters_to_from_parameter_dict_list():
+    parameters = Parameters.from_dict(
+        {
+            "a": [
+                ["1", 0.25, {"vary": False, "min": 0, "max": 8}],
+                ["2", 0.75, {"expr": "1 - $a.1", "non-negative": True}],
+            ],
+            "b": [
+                ["total", 2],
+                ["branch1", {"expr": "$b.total * $a.1"}],
+                ["branch2", {"expr": "$b.total * $a.2"}],
+            ],
+        }
+    )
+
+    parameters_dict_list = parameters.to_parameter_dict_list()
+
+    assert parameters == Parameters.from_parameter_dict_list(parameters_dict_list)
+
+
+def test_parameters_equal():
+    """Instances of ``Parameters`` that have the same values are equal."""
+    params = [2, 1]
+
+    parameters_1 = Parameters.from_list(params)
+    parameters_2 = Parameters.from_list(params)
+
+    assert parameters_1 == parameters_2
+
+
+@pytest.mark.parametrize(
+    "key_name, value_1, value_2",
+    (
+        ("vary", True, False),
+        ("min", -np.inf, -1),
+        ("max", np.inf, 1),
+        ("expression", None, "$a.1*10"),
+        ("standard-error", np.nan, 1),
+        ("non-negative", True, False),
+    ),
+)
+def test_parameters_not_equal(key_name: str, value_1: Any, value_2: Any):
+    """Instances of ``Parameters`` that have different values are not equal."""
+    parameters_1 = Parameters.from_dict({"a": [["1", 0.25, {key_name: value_1}]]})
+    parameters_2 = Parameters.from_dict({"a": [["1", 0.25, {key_name: value_2}]]})
+
+    assert parameters_1 != parameters_2
+
+
+def test_parameters_equal_error():
+    """Raise if rhs operand is not an instance of ``Parameters``."""
+    param_dict = {"foo": Parameter(label="foo")}
+    with pytest.raises(NotImplementedError) as excinfo:
+        Parameters(param_dict) == param_dict
+
+    assert (
+        str(excinfo.value)
+        == "Parameters can only be compared with instances of Parameters, not with 'dict'."
+ ) + + +def test_parameter_scientific_values(): + values = [5e3, -4.2e-4, 3e-2, -2e6] + assert [p.value for p in Parameters.from_list(values).all()] == values + + +def test_parameter_group_copy(): + parameters = Parameters.from_dict( + { + "a": [ + ["1", 0.25, {"vary": False, "min": 0, "max": 8}], + ["2", 0.75, {"expr": "1 - $a.1", "non-negative": True}], + ], + "b": [ + ["total", 2], + ["branch1", {"expr": "$b.total * $a.1"}], + ["branch2", {"expr": "$b.total * $a.2"}], + ], + } + ) + + copy = parameters.copy() + + assert parameters is not copy + assert parameters == parameters.copy() + + +def test_parameter_expressions(): + parameters = Parameters.from_list( + [["1", 2], ["2", 5], ["3", {"expr": "$1 * exp($2)"}], ["4", {"expr": "2"}]] + ) + + assert parameters.get("3").expression is not None + assert not parameters.get("3").vary + assert parameters.get("3").value == 2 * np.exp(5) + assert parameters.get("3").value == parameters.get("1") * np.exp(parameters.get("2")) + assert parameters.get("4").value == 2 + + with pytest.raises(ValueError): + Parameters.from_list([["3", {"expr": "None"}]]) + + +def test_parameters_array_conversion(): + parameters = Parameters.from_list( + [ + ["1", 1, {"non-negative": False, "min": -1, "max": 1, "vary": False}], + ["2", 4e2, {"non-negative": True, "min": 10, "max": 8e2, "vary": True}], + ["3", 2e4], + ] + ) + + labels, values, lower_bounds, upper_bounds = parameters.get_label_value_and_bounds_arrays( + exclude_non_vary=False + ) + + assert len(labels) == 3 + assert len(values) == 3 + assert len(lower_bounds) == 3 + assert len(upper_bounds) == 3 + + assert labels == ["1", "2", "3"] + assert np.allclose(values, [1, np.log(4e2), 2e4]) + assert np.allclose(lower_bounds, [-1, np.log(10), -np.inf]) + assert np.allclose(upper_bounds, [1, np.log(8e2), np.inf]) + + ( + labels_only_vary, + values_only_vary, + lower_bounds_only_vary, + upper_bounds_only_vary, + ) = parameters.get_label_value_and_bounds_arrays(exclude_non_vary=True) + + assert len(labels_only_vary) == 2 + assert len(values_only_vary) == 2 + assert len(lower_bounds_only_vary) == 2 + assert len(upper_bounds_only_vary) == 2 + + assert labels_only_vary == ["2", "3"] + + labels = ["1", "2", "3"] + values = [0, np.log(6e2), 42] + + parameters.set_from_label_and_value_arrays(labels, values) + + values[1] = np.exp(values[1]) + + for i in range(3): + assert parameters.get(f"{i+1}").value == values[i] + + +def test_parameter_group_to_from_df(): + parameters = Parameters.from_dict( + { + "a": [ + ["1", 0.25, {"vary": False, "min": 0, "max": 8}], + ["2", 0.75, {"expr": "1 - $a.1", "non-negative": True}], + ], + "b": [ + ["total", 2], + ["branch1", {"expr": "$b.total * $a.1"}], + ["branch2", {"expr": "$b.total * $a.2"}], + ], + } + ) + + for p in parameters.all(): + p.standard_error = 42 + + parameter_df = parameters.to_dataframe() + + for column in [ + "label", + "value", + "standard_error", + "expression", + "minimum", + "maximum", + "non_negative", + "vary", + ]: + assert column in parameter_df + + assert all(parameter_df["standard_error"] == 42) + + assert parameters == Parameters.from_dataframe(parameter_df) diff --git a/glotaran/parameter/test/test_parameter_group_rendering.py b/glotaran/parameter/test/test_parameters_rendering.py similarity index 76% rename from glotaran/parameter/test/test_parameter_group_rendering.py rename to glotaran/parameter/test/test_parameters_rendering.py index 9452533ec..c10ef9e6b 100644 --- a/glotaran/parameter/test/test_parameter_group_rendering.py +++ 
b/glotaran/parameter/test/test_parameters_rendering.py @@ -1,7 +1,7 @@ from IPython.core.formatters import format_display_data from glotaran.io import load_parameters -from glotaran.parameter.parameter_group import ParameterGroup +from glotaran.parameter.parameters import Parameters PARAMETERS_3C_BASE = """\ irf: @@ -52,12 +52,12 @@ """ # noqa: E501 -def test_param_group_markdown_is_order_independent(): - """Markdown output of ParameterGroup.markdown() is independent of initial order""" +def test_parameters_markdown_is_order_independent(): + """Markdown output of Parameters.markdown() is independent of initial order""" PARAMETERS_3C_INITIAL1 = f"""{PARAMETERS_3C_BASE}\n{PARAMETERS_3C_KINETIC}""" PARAMETERS_3C_INITIAL2 = f"""{PARAMETERS_3C_KINETIC}\n{PARAMETERS_3C_BASE}""" - initial_parameters_ref = ParameterGroup.from_dict( + initial_parameters_ref = Parameters.from_dict( { "j": [["1", 1, {"vary": False, "non-negative": False}]], "kinetic": [ @@ -76,40 +76,45 @@ def test_param_group_markdown_is_order_independent(): assert str(initial_parameters2.markdown()) == RENDERED_MARKDOWN assert str(initial_parameters_ref.markdown()) == RENDERED_MARKDOWN - minimal_params = ParameterGroup.from_dict( + minimal_params = Parameters.from_dict( {"irf": [["center", 1.3, {"standard-error": 0.000012345678}]]} ) assert str(minimal_params.markdown(float_format=".5e")) == RENDERED_MARKDOWN_E5_PRECISION -def test_param_group_repr(): +def test_parameters_repr(): """Repr creates code to recreate the object with from_dict.""" - result = ParameterGroup.from_dict({"foo": {"bar": [["1", 1.0], ["2", 2.0], ["3", 3.0]]}}) - result_short = ParameterGroup.from_dict({"foo": {"bar": [1, 2, 3]}}) - expected = "ParameterGroup.from_dict({'foo': {'bar': [['1', 1.0], ['2', 2.0], ['3', 3.0]]}})" - - assert result == result_short - assert result_short.__repr__() == expected - assert result.__repr__() == expected - assert result == eval(result.__repr__()) + # Needed to eval the Parameters repr + from glotaran.parameter.parameter import Parameter # noqa:401 -def test_param_group_repr_from_list(): - """Repr creates code to recreate the object with from_list.""" - result = ParameterGroup.from_list([["1", 2.3], ["2", 3.0]]) - result_short = ParameterGroup.from_list([2.3, 3.0]) - expected = "ParameterGroup.from_list([['1', 2.3], ['2', 3.0]])" + result = Parameters.from_dict( + { + "foo": { + "bar": [ + ["1", 1.0, {"vary": True}], + ["2", 2.0, {"expression": "$foo.bar.1*2"}], + ["3", 3.0, {"min": -10}], + ] + } + } + ) + expected = ( + "Parameters({'foo.bar.1': Parameter(label='foo.bar.1', value=1.0), " + "'foo.bar.2': Parameter(label='foo.bar.2', value=2.0, expression='$foo.bar.1*2'," + " vary=False), " + "'foo.bar.3': Parameter(label='foo.bar.3', value=3.0, minimum=-10)})" + ) - assert result == result_short + print(result.__repr__()) assert result.__repr__() == expected - assert result_short.__repr__() == expected - assert result == eval(result.__repr__()) + assert result == eval(expected) -def test_param_group_ipython_rendering(): +def test_parameters_ipython_rendering(): """Autorendering in ipython""" - param_group = ParameterGroup.from_dict({"foo": {"bar": [["1", 1.0], ["2", 2.0], ["3", 3.0]]}}) + param_group = Parameters.from_dict({"foo": {"bar": [["1", 1.0], ["2", 2.0], ["3", 3.0]]}}) rendered_obj = format_display_data(param_group)[0] assert "text/markdown" in rendered_obj diff --git a/glotaran/plugin_system/project_io_registration.py b/glotaran/plugin_system/project_io_registration.py index c981702a0..8c5ca1759 100644 --- 
a/glotaran/plugin_system/project_io_registration.py +++ b/glotaran/plugin_system/project_io_registration.py @@ -38,7 +38,7 @@ from typing import Literal from glotaran.model import Model - from glotaran.parameter import ParameterGroup + from glotaran.parameter import Parameters from glotaran.project import Result from glotaran.project import Scheme from glotaran.typing import StrOrPath @@ -266,8 +266,8 @@ def save_model( @not_implemented_to_value_error -def load_parameters(file_name: StrOrPath, format_name: str = None, **kwargs) -> ParameterGroup: - """Create a :class:`ParameterGroup` instance from the specs defined in a file. +def load_parameters(file_name: StrOrPath, format_name: str = None, **kwargs) -> Parameters: + """Create a :class:`Parameters` instance from the specs defined in a file. Parameters ---------- @@ -281,8 +281,10 @@ def load_parameters(file_name: StrOrPath, format_name: str = None, **kwargs) -> Returns ------- - ParameterGroup - :class:`ParameterGroup` instance created from the file. + Parameters + :class:`Parameters` instance created from the file. + + .. # noqa: D414 """ io = get_project_io(format_name or inferr_file_format(file_name)) parameters = io.load_parameters( # type: ignore[call-arg] @@ -295,7 +297,7 @@ def load_parameters(file_name: StrOrPath, format_name: str = None, **kwargs) -> @not_implemented_to_value_error def save_parameters( - parameters: ParameterGroup, + parameters: Parameters, file_name: StrOrPath, format_name: str = None, *, @@ -303,12 +305,12 @@ def save_parameters( update_source_path: bool = True, **kwargs: Any, ) -> None: - """Save a :class:`ParameterGroup` instance to a spec file. + """Save a :class:`Parameters` instance to a spec file. Parameters ---------- - parameters : ParameterGroup - :class:`ParameterGroup` instance to save to specs file. + parameters : Parameters + :class:`Parameters` instance to save to specs file. file_name : StrOrPath File to write the parameter specs to. 
format_name : str diff --git a/glotaran/plugin_system/test/test_megacomplex_registration.py b/glotaran/plugin_system/test/test_megacomplex_registration.py index 64a2865bd..f5def08bb 100644 --- a/glotaran/plugin_system/test/test_megacomplex_registration.py +++ b/glotaran/plugin_system/test/test_megacomplex_registration.py @@ -61,9 +61,9 @@ def test_register_megacomplex_warning(): with pytest.warns(PluginOverwriteWarning, match="DecayMegacomplex.+bar.+Dummy") as record: - @megacomplex(register_as="bar") + @megacomplex() class Dummy(DecayMegacomplex): - pass + type: str = "bar" assert len(record) == 1 assert Path(record[0].filename) == Path(__file__) diff --git a/glotaran/plugin_system/test/test_project_io_registration.py b/glotaran/plugin_system/test/test_project_io_registration.py index ef026311e..28c84812f 100644 --- a/glotaran/plugin_system/test/test_project_io_registration.py +++ b/glotaran/plugin_system/test/test_project_io_registration.py @@ -9,7 +9,7 @@ from glotaran.builtin.io.pandas.csv import CsvProjectIo from glotaran.builtin.io.yml.yml import YmlProjectIo from glotaran.io import ProjectIoInterface -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.plugin_system.base_registry import PluginOverwriteWarning from glotaran.plugin_system.base_registry import __PluginRegistry from glotaran.plugin_system.project_io_registration import SAVING_OPTIONS_DEFAULT @@ -71,14 +71,14 @@ def save_model( # type:ignore[override] } ) - def load_parameters(self, file_name: StrOrPath, **kwargs: Any) -> ParameterGroup: + def load_parameters(self, file_name: StrOrPath, **kwargs: Any) -> Parameters: mock_obj = MockFileLoadable() mock_obj.func_args = {"file_name": file_name, **kwargs} return mock_obj # type:ignore[return-value] def save_parameters( # type:ignore[override] self, - parameters: ParameterGroup, + parameters: Parameters, file_name: StrOrPath, **kwargs: Any, ): diff --git a/glotaran/project/generators/generator.py b/glotaran/project/generators/generator.py index 8c0043181..09bb69191 100644 --- a/glotaran/project/generators/generator.py +++ b/glotaran/project/generators/generator.py @@ -7,6 +7,9 @@ from typing import cast from glotaran.builtin.io.yml.utils import write_dict +from glotaran.builtin.megacomplexes.decay import DecayParallelMegacomplex +from glotaran.builtin.megacomplexes.decay import DecaySequentialMegacomplex +from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex from glotaran.model import Model @@ -219,7 +222,9 @@ def generate_model(*, generator_name: str, generator_arguments: GeneratorArgumen f"Known generators are: {list(generators.keys())}" ) model = generators[generator_name](**generator_arguments) - return Model.from_dict(model) + return Model.create_class_from_megacomplexes( + [DecayParallelMegacomplex, DecaySequentialMegacomplex, SpectralMegacomplex] + )(**model) def generate_model_yml(*, generator_name: str, generator_arguments: GeneratorArguments) -> str: diff --git a/glotaran/project/generators/test/test_genenerate_decay_model.py b/glotaran/project/generators/test/test_genenerate_decay_model.py index dda3cce29..db6b2ec21 100644 --- a/glotaran/project/generators/test/test_genenerate_decay_model.py +++ b/glotaran/project/generators/test/test_genenerate_decay_model.py @@ -1,5 +1,8 @@ import pytest +from glotaran.builtin.megacomplexes.decay import DecayParallelMegacomplex +from glotaran.builtin.megacomplexes.decay import DecaySequentialMegacomplex +from glotaran.builtin.megacomplexes.spectral import 
SpectralMegacomplex from glotaran.project.generators.generator import generate_model @@ -25,11 +28,10 @@ def test_generate_parallel_model(megacomplex_type: str, irf: bool, spectral: boo megacomplex = model.megacomplex[ # type:ignore[attr-defined] f"megacomplex_{megacomplex_type}_decay" ] + assert isinstance(megacomplex, (DecayParallelMegacomplex, DecaySequentialMegacomplex)) assert megacomplex.type == f"decay-{megacomplex_type}" assert megacomplex.compartments == expected_compartments - assert [r.full_label for r in megacomplex.rates] == [ - f"rates.species_{i+1}" for i in range(nr_compartments) - ] + assert megacomplex.rates == [f"rates.species_{i+1}" for i in range(nr_compartments)] assert "dataset_1" in model.dataset # type:ignore[attr-defined] dataset = model.dataset["dataset_1"] # type:ignore[attr-defined] @@ -38,6 +40,7 @@ def test_generate_parallel_model(megacomplex_type: str, irf: bool, spectral: boo if spectral: assert "megacomplex_spectral" in model.megacomplex # type:ignore[attr-defined] megacomplex = model.megacomplex["megacomplex_spectral"] # type:ignore[attr-defined] + assert isinstance(megacomplex, SpectralMegacomplex) assert expected_compartments == list(megacomplex.shape.keys()) expected_shapes = [f"shape_species_{i+1}" for i in range(nr_compartments)] assert expected_shapes == list(megacomplex.shape.values()) @@ -46,15 +49,15 @@ def test_generate_parallel_model(megacomplex_type: str, irf: bool, spectral: boo assert shape in model.shape # type:ignore[attr-defined] assert model.shape[shape].type == "gaussian" # type:ignore[attr-defined] assert ( - model.shape[shape].amplitude.full_label # type:ignore[attr-defined] + model.shape[shape].amplitude # type:ignore[attr-defined] == f"shapes.species_{i+1}.amplitude" ) assert ( - model.shape[shape].location.full_label # type:ignore[attr-defined] + model.shape[shape].location # type:ignore[attr-defined] == f"shapes.species_{i+1}.location" ) assert ( - model.shape[shape].width.full_label # type:ignore[attr-defined] + model.shape[shape].width # type:ignore[attr-defined] == f"shapes.species_{i+1}.width" ) assert dataset.global_megacomplex == ["megacomplex_spectral"] @@ -63,9 +66,7 @@ def test_generate_parallel_model(megacomplex_type: str, irf: bool, spectral: boo assert dataset.irf == "gaussian_irf" # type:ignore[attr-defined] assert "gaussian_irf" in model.irf # type:ignore[attr-defined] assert ( - model.irf["gaussian_irf"].center.full_label # type:ignore[attr-defined] + model.irf["gaussian_irf"].center # type:ignore[attr-defined] == "irf.center" ) - assert ( - model.irf["gaussian_irf"].width.full_label == "irf.width" # type:ignore[attr-defined] - ) + assert model.irf["gaussian_irf"].width == "irf.width" # type:ignore[attr-defined] diff --git a/glotaran/project/project.py b/glotaran/project/project.py index 1980d1e59..07d15e780 100644 --- a/glotaran/project/project.py +++ b/glotaran/project/project.py @@ -14,7 +14,7 @@ from glotaran.builtin.io.yml.utils import load_dict from glotaran.model import Model -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.project.project_data_registry import ProjectDataRegistry from glotaran.project.project_model_registry import ProjectModelRegistry from glotaran.project.project_parameter_registry import ProjectParameterRegistry @@ -305,7 +305,7 @@ def parameters(self) -> dict[str, Path]: """ return self._parameter_registry.items - def load_parameters(self, parameters_name: str) -> ParameterGroup: + def load_parameters(self, parameters_name: str) -> 
Parameters: """Load parameters. Parameters @@ -313,15 +313,17 @@ def load_parameters(self, parameters_name: str) -> ParameterGroup: parameters_name : str The name of the parameters. - Returns - ------- - ParameterGroup - The loaded parameters. - Raises ------ ValueError Raised if parameters do not exist. + + Returns + ------- + Parameters + The loaded parameters. + + .. # noqa: D414 """ try: return self._parameter_registry.load_item(parameters_name) diff --git a/glotaran/project/project_parameter_registry.py b/glotaran/project/project_parameter_registry.py index c37ac27d2..72e8fd2c1 100644 --- a/glotaran/project/project_parameter_registry.py +++ b/glotaran/project/project_parameter_registry.py @@ -8,7 +8,7 @@ from glotaran.io import load_parameters from glotaran.io import save_parameters from glotaran.model import Model -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.project.project_registry import ProjectRegistry @@ -54,7 +54,9 @@ def generate_parameters( FileExistsError Raised if parameters is already existing and `allow_overwrite=False`. """ - parameters = model.generate_parameters() + parameters = model.generate_parameters().to_parameter_dict_or_list( + serialize_parameters=True + ) parameter_file = self.directory / f"{name}.{format_name}" if parameter_file.exists() and ignore_existing: @@ -64,12 +66,14 @@ def generate_parameters( raise FileExistsError( f"Parameters {name!r} already exists and `allow_overwrite=False`." ) + if format_name in ["yml", "yaml"]: write_dict(parameters, file_name=parameter_file, offset=0) elif format_name == "csv": - parameter_group = ( - ParameterGroup.from_dict(parameters) + save_parameters( + Parameters.from_dict(parameters) if isinstance(parameters, dict) - else ParameterGroup.from_list(parameters) + else Parameters.from_list(parameters), + parameter_file, + allow_overwrite=allow_overwrite, ) - save_parameters(parameter_group, parameter_file, allow_overwrite=allow_overwrite) diff --git a/glotaran/project/result.py b/glotaran/project/result.py index 15c6ec0ee..61cbd2ab9 100644 --- a/glotaran/project/result.py +++ b/glotaran/project/result.py @@ -20,8 +20,8 @@ from glotaran.io import save_result from glotaran.model import Model from glotaran.optimization.optimization_history import OptimizationHistory -from glotaran.parameter import ParameterGroup from glotaran.parameter import ParameterHistory +from glotaran.parameter import Parameters from glotaran.project.dataclass_helpers import exclude_from_dict_field from glotaran.project.dataclass_helpers import file_loadable_field from glotaran.project.dataclass_helpers import init_file_loadable_fields @@ -59,12 +59,12 @@ class Result: scheme: Scheme = file_loadable_field(Scheme) # type:ignore[type-var] - initial_parameters: ParameterGroup = file_loadable_field( # type:ignore[type-var] - ParameterGroup + initial_parameters: Parameters = file_loadable_field( # type:ignore[type-var] + Parameters ) - optimized_parameters: ParameterGroup = file_loadable_field( # type:ignore[type-var] - ParameterGroup + optimized_parameters: Parameters = file_loadable_field( # type:ignore[type-var] + Parameters ) parameter_history: ParameterHistory = file_loadable_field( # type:ignore[type-var] diff --git a/glotaran/project/scheme.py b/glotaran/project/scheme.py index ac2b1afef..7e47cc281 100644 --- a/glotaran/project/scheme.py +++ b/glotaran/project/scheme.py @@ -3,14 +3,11 @@ from dataclasses import dataclass from dataclasses import field -from dataclasses import fields from 
diff --git a/glotaran/project/result.py b/glotaran/project/result.py
index 15c6ec0ee..61cbd2ab9 100644
--- a/glotaran/project/result.py
+++ b/glotaran/project/result.py
@@ -20,8 +20,8 @@
 from glotaran.io import save_result
 from glotaran.model import Model
 from glotaran.optimization.optimization_history import OptimizationHistory
-from glotaran.parameter import ParameterGroup
 from glotaran.parameter import ParameterHistory
+from glotaran.parameter import Parameters
 from glotaran.project.dataclass_helpers import exclude_from_dict_field
 from glotaran.project.dataclass_helpers import file_loadable_field
 from glotaran.project.dataclass_helpers import init_file_loadable_fields
@@ -59,12 +59,12 @@ class Result:
 
     scheme: Scheme = file_loadable_field(Scheme)  # type:ignore[type-var]
 
-    initial_parameters: ParameterGroup = file_loadable_field(  # type:ignore[type-var]
-        ParameterGroup
+    initial_parameters: Parameters = file_loadable_field(  # type:ignore[type-var]
+        Parameters
     )
 
-    optimized_parameters: ParameterGroup = file_loadable_field(  # type:ignore[type-var]
-        ParameterGroup
+    optimized_parameters: Parameters = file_loadable_field(  # type:ignore[type-var]
+        Parameters
     )
 
     parameter_history: ParameterHistory = file_loadable_field(  # type:ignore[type-var]
diff --git a/glotaran/project/scheme.py b/glotaran/project/scheme.py
index ac2b1afef..7e47cc281 100644
--- a/glotaran/project/scheme.py
+++ b/glotaran/project/scheme.py
@@ -3,14 +3,11 @@
 from dataclasses import dataclass
 from dataclasses import field
-from dataclasses import fields
 from typing import TYPE_CHECKING
 
-from glotaran.deprecation import warn_deprecated
 from glotaran.io import load_scheme
 from glotaran.model import Model
-from glotaran.parameter import ParameterGroup
-from glotaran.project.dataclass_helpers import exclude_from_dict_field
+from glotaran.parameter import Parameters
 from glotaran.project.dataclass_helpers import file_loadable_field
 from glotaran.project.dataclass_helpers import init_file_loadable_fields
 from glotaran.utils.io import DatasetMapping
 
@@ -35,7 +32,7 @@ class Scheme:
     """
 
     model: Model = file_loadable_field(Model)  # type:ignore[type-var]
-    parameters: ParameterGroup = file_loadable_field(ParameterGroup)  # type:ignore[type-var]
+    parameters: Parameters = file_loadable_field(Parameters)  # type:ignore[type-var]
     data: Mapping[str, xr.Dataset] = file_loadable_field(
         DatasetMapping, is_wrapper_class=True
     )  # type:ignore[type-var]
@@ -44,9 +41,6 @@ class Scheme:
     clp_link_method: Literal["nearest", "backward", "forward"] = "nearest"
 
     maximum_number_function_evaluations: int | None = None
-    non_negative_least_squares: bool | None = exclude_from_dict_field(None)
-    group_tolerance: float | None = exclude_from_dict_field(None)
-    group: bool | None = exclude_from_dict_field(None)
     add_svd: bool = True
     ftol: float = 1e-8
     gtol: float = 1e-8
@@ -68,59 +62,6 @@ def __post_init__(self):
         """Override attributes after initialization."""
         init_file_loadable_fields(self)
 
-        # Deprecations
-        if self.non_negative_least_squares is not None:
-            warn_deprecated(
-                deprecated_qual_name_usage=(
-                    "glotaran.project.Scheme(..., non_negative_least_squares=...)"
-                ),
-                new_qual_name_usage="dataset_groups.default.residual_function",
-                to_be_removed_in_version="0.7.0",
-                check_qual_names=(True, False),
-                stacklevel=4,
-            )
-
-            default_group = self.model.dataset_group_models["default"]
-            if self.non_negative_least_squares is True:
-                default_group.residual_function = "non_negative_least_squares"
-            else:
-                default_group.residual_function = "variable_projection"
-            for field_item in fields(self):
-                if field_item.name == "non_negative_least_squares":
-                    field_item.metadata = {}
-
-        if self.group is not None:
-            warn_deprecated(
-                deprecated_qual_name_usage="glotaran.project.Scheme(..., group=...)",
-                new_qual_name_usage="dataset_groups.default.link_clp",
-                to_be_removed_in_version="0.7.0",
-                check_qual_names=(True, False),
-                stacklevel=4,
-            )
-            self.model.dataset_group_models["default"].link_clp = self.group
-            for field_item in fields(self):
-                if field_item.name == "group":
-                    field_item.metadata = {}
-
-        if self.group_tolerance is not None:
-            warn_deprecated(
-                deprecated_qual_name_usage="glotaran.project.Scheme(..., group_tolerance=...)",
-                new_qual_name_usage="glotaran.project.Scheme(..., clp_link_tolerance=...)",
-                to_be_removed_in_version="0.7.0",
-                stacklevel=4,
-            )
-            self.clp_link_tolerance = self.group_tolerance
-
-    def problem_list(self) -> list[str]:
-        """Return a list with all problems in the model and missing parameters.
-
-        Returns
-        -------
-        list[str]
-            A list of all problems found in the scheme's model.
-        """
-        return self.model.problem_list(self.parameters)
-
     def validate(self) -> MarkdownStr:
         """Return a string listing all problems in the model and missing parameters.
@@ -153,8 +94,6 @@ def markdown(self):
         model_markdown_str = self.model.markdown(parameters=self.parameters)
 
         markdown_str = "\n\n__Scheme__\n\n"
-        if self.non_negative_least_squares is not None:
-            markdown_str += f"* *non_negative_least_squares*: {self.non_negative_least_squares}\n"
         markdown_str += (
             "* *maximum_number_function_evaluations*: "
             f"{self.maximum_number_function_evaluations}\n"
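Because the shims are deleted outright instead of warning, code that still passes the old keyword arguments now fails at construction. A migration sketch based on the replacement paths named in the removed deprecation messages (the `dataset_group_models` attribute is taken from the removed 0.6 code, so verify it against the attrs-based model in your version):

```python
# non_negative_least_squares=True  ->  dataset_groups.default.residual_function
# group=...                        ->  dataset_groups.default.link_clp
# group_tolerance=...              ->  Scheme(..., clp_link_tolerance=...)
default_group = scheme.model.dataset_group_models["default"]
default_group.residual_function = "non_negative_least_squares"
default_group.link_clp = True
scheme.clp_link_tolerance = 0.1
```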
diff --git a/glotaran/project/test/test_project.py b/glotaran/project/test/test_project.py
index f36d45b20..7a6758b50 100644
--- a/glotaran/project/test/test_project.py
+++ b/glotaran/project/test/test_project.py
@@ -148,7 +148,7 @@ def test_generate_parameters(
     for parameter in model.get_parameter_labels():
         assert parameters.has(parameter)
 
-    assert len(list(filter(lambda p: p[0].startswith("rates"), parameters.all()))) == 5
+    assert len(list(filter(lambda p: p.label.startswith("rates"), parameters.all()))) == 5
 
     with pytest.raises(FileExistsError) as exc_info:
         project.generate_parameters("test_model", parameters_name=name, format_name=fmt)
@@ -318,7 +318,7 @@ def test_generators_allow_overwrite(project_folder: Path, project_file: Path):
     parameters = load_parameters(parameter_file)
 
-    assert len(list(filter(lambda p: p[0].startswith("rates"), parameters.all()))) == 5
+    assert len(list(filter(lambda p: p.label.startswith("rates"), parameters.all()))) == 5
 
     project.generate_model(
         "test_model", "decay_parallel", {"nr_compartments": 3}, allow_overwrite=True
     )
@@ -335,7 +335,7 @@ def test_generators_allow_overwrite(project_folder: Path, project_file: Path):
     project.generate_parameters("test", allow_overwrite=True)
 
     parameters = load_parameters(parameter_file)
-    assert len(list(filter(lambda p: p[0].startswith("rates"), parameters.all()))) == 3
+    assert len(list(filter(lambda p: p.label.startswith("rates"), parameters.all()))) == 3
 
 
 def test_missing_file_errors(tmp_path: Path):
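The `p[0]` to `p.label` rewrites track a change in the iteration contract: `ParameterGroup.all()` yielded `(label, parameter)` tuples, whereas `Parameters.all()` yields `Parameter` objects directly. Schematically (assuming a loaded `parameters` object as in the tests above):

```python
# Before (ParameterGroup): for label, parameter in parameters.all(): ...
# Now (Parameters): each yielded item is a Parameter carrying its own label.
rates = [p for p in parameters.all() if p.label.startswith("rates")]
```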
""" - dataset_model = model.dataset[dataset].fill(model, parameters) # type:ignore[attr-defined] - model_dimension = dataset_model.get_model_dimension() + dataset_model = fill_item(model.dataset[dataset], model, parameters) + model_dimension = get_dataset_model_model_dimension(dataset_model) model_axis = coordinates[model_dimension] global_dimension = next(dim for dim in coordinates if dim != model_dimension) global_axis = coordinates[global_dimension] - if dataset_model.has_global_model(): + if has_dataset_model_global_model(dataset_model): result = simulate_full_model( dataset_model, global_dimension, global_axis, model_dimension, model_axis ) @@ -126,10 +130,12 @@ def simulate_from_clp( raise ValueError("Missing coordinate 'clp_label' in clp.") matrices = ( [ - MatrixProvider.calculate_dataset_matrix(dataset_model, index, global_axis, model_axis) + MatrixProvider.calculate_dataset_matrix( + dataset_model, index, np.array(global_axis), model_axis + ) for index, _ in enumerate(global_axis) ] - if dataset_model.is_index_dependent() + if is_dataset_model_index_dependent(dataset_model) else [ MatrixProvider.calculate_dataset_matrix(dataset_model, None, global_axis, model_axis) ] @@ -186,8 +192,8 @@ def simulate_full_model( Raised if at least one of the dataset model's global megacomplexes is index dependent. """ if any( - m.index_dependent(dataset_model) # type:ignore[attr-defined] - for m in dataset_model.global_megacomplex + m.index_dependent(dataset_model) # type:ignore[union-attr] + for m in dataset_model.global_megacomplex # type:ignore[union-attr] ): raise ValueError("Index dependent models for global dimension are not supported.") diff --git a/glotaran/simulation/test/test_simulation.py b/glotaran/simulation/test/test_simulation.py index 92972fc53..1e372a43f 100644 --- a/glotaran/simulation/test/test_simulation.py +++ b/glotaran/simulation/test/test_simulation.py @@ -2,18 +2,18 @@ import pytest from glotaran.optimization.test.models import SimpleTestModel -from glotaran.parameter import ParameterGroup +from glotaran.parameter import Parameters from glotaran.simulation import simulate @pytest.mark.parametrize("index_dependent", [True, False]) @pytest.mark.parametrize("noise", [True, False]) def test_simulate_dataset(index_dependent, noise): - model = SimpleTestModel.from_dict( - { + model = SimpleTestModel( + **{ "megacomplex": { - "m1": {"is_index_dependent": index_dependent}, - "m2": {"is_index_dependent": False}, + "m1": {"type": "simple-test-mc", "is_index_dependent": index_dependent}, + "m2": {"type": "simple-test-mc", "is_index_dependent": False}, }, "dataset": { "dataset1": { @@ -26,7 +26,7 @@ def test_simulate_dataset(index_dependent, noise): print(model.validate()) assert model.valid() - parameter = ParameterGroup.from_list([1, 1]) + parameter = Parameters.from_list([1, 1]) print(model.validate(parameter)) assert model.valid(parameter) diff --git a/glotaran/testing/test/test_plugin_system.py b/glotaran/testing/test/test_plugin_system.py index df1b2d9b7..a0e33da9e 100644 --- a/glotaran/testing/test/test_plugin_system.py +++ b/glotaran/testing/test/test_plugin_system.py @@ -13,7 +13,7 @@ from glotaran.testing.plugin_system import monkeypatch_plugin_registry_project_io -@megacomplex(dimension="test") +@megacomplex() class DummyMegacomplex(Megacomplex): pass diff --git a/glotaran/utils/attrs_helper.py b/glotaran/utils/attrs_helper.py new file mode 100644 index 000000000..dc356b3b2 --- /dev/null +++ b/glotaran/utils/attrs_helper.py @@ -0,0 +1,54 @@ +"""Helper functions for 
attrs.""" +from glotaran.utils.helpers import nan_or_equal + + +def no_default_vals_in_repr(cls): + """Class decorator to omits attributes from repr that have their default value. + + Needs to be on top of the ``attr.define`` decorator. + Based on: https://stackoverflow.com/a/47663099/3990615 + + Parameters + ---------- + cls + Class decorated with ``attr.define``. + + Returns + ------- + type[cls] + """ + defaults = { + attribute.name: attribute.default + for attribute in cls.__attrs_attrs__ + if attribute.repr is True + } + + def repr_(self) -> str: + """Return string representing the instance. + + Parameters + ---------- + self: cls + Instance of the wrapped class. + + Returns + ------- + str + """ + real_cls = self.__class__ + + if (qualname := getattr(real_cls, "__qualname__", None)) is not None: + class_name = qualname.rsplit(">.", 1)[-1] + else: + class_name = real_cls.__name__ + + args_str = ", ".join( + f"{name}={repr(getattr(self, name))}" + for name in defaults + if not nan_or_equal(getattr(self, name), defaults[name]) + ) + + return f"{class_name}({args_str})" + + cls.__repr__ = repr_ + return cls diff --git a/glotaran/utils/helpers.py b/glotaran/utils/helpers.py new file mode 100644 index 000000000..6668b02d3 --- /dev/null +++ b/glotaran/utils/helpers.py @@ -0,0 +1,27 @@ +"""Module containing general helper functions.""" + +from typing import Any + +import numpy as np + + +def nan_or_equal(lhs: Any, rhs: Any) -> bool: + """Compare values which can be nan for equality. + + This helper function is needed because ``np.nan == np.nan`` returns ``False``. + + Parameters + ---------- + lhs: Any + Left hand side value. + rhs: Any + Right hand side value. + + Returns + ------- + bool + Whether or not values are equal. + """ + if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): + return (np.isnan(lhs) and np.isnan(rhs)) or lhs == rhs + return lhs == rhs diff --git a/glotaran/utils/test/test_helpers.py b/glotaran/utils/test/test_helpers.py new file mode 100644 index 000000000..2f53aab4e --- /dev/null +++ b/glotaran/utils/test/test_helpers.py @@ -0,0 +1,23 @@ +"""Tests for ``glotaran.utils.numeric_helpers``.""" + +from typing import Any + +import numpy as np +import pytest + +from glotaran.utils.helpers import nan_or_equal + + +@pytest.mark.parametrize( + "lhs, rhs, expected", + ( + ("foo", "foo", True), + (np.nan, np.nan, True), + (1, 1, True), + (1, 2, False), + ("foo", "bar", False), + ), +) +def test_nan_or_equal(lhs: Any, rhs: Any, expected: bool): + """Only ``False`` if values actually differ.""" + assert nan_or_equal(lhs, rhs) == expected diff --git a/readthedocs.yml b/readthedocs.yml index 76230fee1..d2ecabe62 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -2,11 +2,15 @@ version: 2 formats: all +sphinx: + configuration: docs/source/conf.py + build: - image: latest + os: ubuntu-22.04 + tools: + python: "3.10" python: - version: 3.8 install: - requirements: docs/requirements.txt - method: pip diff --git a/requirements_dev.txt b/requirements_dev.txt index a5df586ec..9a76c1f46 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -4,6 +4,7 @@ wheel>=0.30.0 # glotaran setup dependencies asteval==0.9.27 +attrs == 22.1.0 click==8.1.3 netCDF4==1.6.1 numba==0.56.3 diff --git a/setup.cfg b/setup.cfg index 47b666247..55309008c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,8 +17,6 @@ classifiers = Natural Language :: English Operating System :: OS Independent Programming Language :: Python :: 3 - Programming Language :: Python :: 3.8 - Programming Language :: 
diff --git a/readthedocs.yml b/readthedocs.yml
index 76230fee1..d2ecabe62 100644
--- a/readthedocs.yml
+++ b/readthedocs.yml
@@ -2,11 +2,15 @@ version: 2
 
 formats: all
 
+sphinx:
+  configuration: docs/source/conf.py
+
 build:
-  image: latest
+  os: ubuntu-22.04
+  tools:
+    python: "3.10"
 
 python:
-  version: 3.8
   install:
     - requirements: docs/requirements.txt
       method: pip
diff --git a/requirements_dev.txt b/requirements_dev.txt
index a5df586ec..9a76c1f46 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -4,6 +4,7 @@ wheel>=0.30.0
 
 # glotaran setup dependencies
 asteval==0.9.27
+attrs == 22.1.0
 click==8.1.3
 netCDF4==1.6.1
 numba==0.56.3
diff --git a/setup.cfg b/setup.cfg
index 47b666247..55309008c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -17,8 +17,6 @@ classifiers =
     Natural Language :: English
    Operating System :: OS Independent
     Programming Language :: Python :: 3
-    Programming Language :: Python :: 3.8
-    Programming Language :: Python :: 3.9
     Programming Language :: Python :: 3.10
     Topic :: Scientific/Engineering
     Topic :: Scientific/Engineering :: Chemistry
@@ -34,6 +32,7 @@ project_urls =
 packages = find:
 install_requires =
     asteval>=0.9.22
+    attrs>=22.1.0
     click>=8.1.3
     netCDF4>=1.5.3
     numba>=0.52
@@ -48,7 +47,7 @@ install_requires =
     setuptools>=41.2
     tabulate>=0.8.8
     xarray>=2022.3.0
-python_requires = >=3.8, <3.11
+python_requires = >=3.10, <3.11
 setup_requires =
     setuptools>=58.0.4
 tests_require = pytest
@@ -128,7 +127,7 @@ ignore_errors = False
 [mypy-glotaran.simulation.*]
 ignore_errors = False
 
-[mypy-glotaran.model.property]
+[mypy-glotaran.model.*]
 ignore_errors = False
 
 [mypy-glotaran.builtin.io.pandas.*]