From 6051f9c2362823c09ec4c57a7c8066620a4c2516 Mon Sep 17 00:00:00 2001 From: Rico Haeuselmann Date: Mon, 27 Jan 2025 14:12:01 +0100 Subject: [PATCH] Improve ConfigData & related (#101) * [wip] first pass at ConfigData and related classes * replace canonical data models with improved named base * update doctests and unit tests --- src/sirocco/core/graph_items.py | 7 +- src/sirocco/parsing/_yaml_data_models.py | 159 +++++++++++++----- tests/unit_tests/core/test_workflow.py | 8 +- .../parsing/test_yaml_data_models.py | 31 +++- 4 files changed, 156 insertions(+), 49 deletions(-) diff --git a/src/sirocco/core/graph_items.py b/src/sirocco/core/graph_items.py index 99b01851..2453d2b3 100644 --- a/src/sirocco/core/graph_items.py +++ b/src/sirocco/core/graph_items.py @@ -15,7 +15,12 @@ from datetime import datetime from pathlib import Path - from sirocco.parsing._yaml_data_models import ConfigBaseData, ConfigCycleTask, ConfigTask, TargetNodesBaseModel + from sirocco.parsing._yaml_data_models import ( + ConfigBaseData, + ConfigCycleTask, + ConfigTask, + TargetNodesBaseModel, + ) @dataclass diff --git a/src/sirocco/parsing/_yaml_data_models.py b/src/sirocco/parsing/_yaml_data_models.py index 5ec12f40..48568a23 100644 --- a/src/sirocco/parsing/_yaml_data_models.py +++ b/src/sirocco/parsing/_yaml_data_models.py @@ -1,5 +1,6 @@ from __future__ import annotations +import enum import itertools import time import typing @@ -25,46 +26,68 @@ class _NamedBaseModel(BaseModel): - """Base class for all classes with a key that specifies their name. - - For example: + """ + Base model for reading names from yaml keys *or* keyword args to the constructor. - .. yaml + Reading from key-value pairs in yaml is also supported in order to enable + the standard constructor usage from Python, as demonstrated in the below + examples. On it's own it is not considered desirable. - - property_name: - property: true + Examples: - When parsing with this as parent class it is converted to - `{"name": "propery_name", "property": True}`. + >>> _NamedBaseModel(name="foo") + _NamedBaseModel(name='foo') + + >>> _NamedBaseModel(foo={}) + _NamedBaseModel(name='foo') + + >>> import pydantic_yaml, textwrap + >>> pydantic_yaml.parse_yaml_raw_as( + ... _NamedBaseModel, + ... textwrap.dedent(''' + ... foo: + ... '''), + ... ) + _NamedBaseModel(name='foo') + + >>> pydantic_yaml.parse_yaml_raw_as( + ... _NamedBaseModel, + ... textwrap.dedent(''' + ... name: foo + ... '''), + ... ) + _NamedBaseModel(name='foo') """ name: str - def __init__(self, /, **data): - super().__init__(**self.merge_name_and_specs(data)) - - @staticmethod - def merge_name_and_specs(data: dict) -> dict: - """ - Converts dict of form - - `{my_name: {'spec_0': ..., ..., 'spec_n': ...}` - - to - - `{'name': my_name, 'spec_0': ..., ..., 'spec_n': ...}` + @model_validator(mode="before") + @classmethod + def reformat_named_object(cls, data: Any) -> Any: + return cls.extract_merge_name(data) - by copy. - """ - name_and_spec = {} - if len(data) != 1: - msg = f"Expected dict with one element of the form {{'name': specification}} but got {data}." - raise ValueError(msg) - name_and_spec["name"] = next(iter(data.keys())) - # if no specification specified e.g. "- my_name:" - if (spec := next(iter(data.values()))) is not None: - name_and_spec.update(spec) - return name_and_spec + @classmethod + def extract_merge_name(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + if len(data) == 1: + key, value = next(iter(data.items())) + match key: + case str(): + match value: + case str() if key == "name": + pass + case dict() if "name" not in value: + data = value | {"name": key} + case None: + data = {"name": key} + case _: + msg = f"{cls.__name__} may only be used for named objects, not values (got {data})." + raise TypeError(msg) + case _: + msg = f"{cls.__name__} requires name to be a str (got {key})." + raise TypeError(msg) + return data class _WhenBaseModel(BaseModel): @@ -434,10 +457,15 @@ def check_nml(cls, nml_list: list[Any]) -> ConfigNamelist: return namelists +class DataType(enum.StrEnum): + FILE = enum.auto() + DIR = enum.auto() + + @dataclass class ConfigBaseDataSpecs: - type: str | None = None - src: str | None = None + type: DataType + src: str format: str | None = None computer: str | None = None @@ -445,12 +473,31 @@ class ConfigBaseDataSpecs: class ConfigBaseData(_NamedBaseModel, ConfigBaseDataSpecs): """ To create an instance of a data defined in a workflow file. + + Examples: + + yaml snippet: + + >>> import textwrap + >>> import pydantic_yaml + >>> snippet = textwrap.dedent( + ... ''' + ... foo: + ... type: "file" + ... src: "foo.txt" + ... ''' + ... ) + >>> pydantic_yaml.parse_yaml_raw_as(ConfigBaseData, snippet) + ConfigBaseData(type=, src='foo.txt', format=None, computer=None, name='foo', parameters=[]) + + + from python: + + >>> ConfigBaseData(name="foo", type=DataType.FILE, src="foo.txt") + ConfigBaseData(type=, src='foo.txt', format=None, computer=None, name='foo', parameters=[]) """ parameters: list[str] = [] - type: str | None = None - src: str | None = None - format: str | None = None @field_validator("type") @classmethod @@ -478,7 +525,36 @@ def invalid_field(cls, value: str | None) -> str | None: class ConfigData(BaseModel): - """To create the container of available and generated data""" + """ + To create the container of available and generated data + + Example: + + yaml snippet: + + >>> import textwrap + >>> import pydantic_yaml + >>> snippet = textwrap.dedent( + ... ''' + ... available: + ... - foo: + ... type: "file" + ... src: "foo.txt" + ... generated: + ... - bar: + ... type: "file" + ... src: "bar.txt" + ... ''' + ... ) + >>> data = pydantic_yaml.parse_yaml_raw_as(ConfigData, snippet) + >>> assert data.available[0].name == "foo" + >>> assert data.generated[0].name == "bar" + + from python: + + >>> ConfigData() + ConfigData(available=[], generated=[]) + """ available: list[ConfigAvailableData] = [] generated: list[ConfigGeneratedData] = [] @@ -489,7 +565,7 @@ def get_plugin_from_named_base_model( ) -> str: if isinstance(data, (ConfigRootTask, ConfigShellTask, ConfigIconTask)): return data.plugin - name_and_specs = _NamedBaseModel.merge_name_and_specs(data) + name_and_specs = ConfigBaseTask.extract_merge_name(data) if name_and_specs.get("name", None) == "ROOT": return ConfigRootTask.plugin plugin = name_and_specs.get("plugin", None) @@ -529,8 +605,12 @@ class ConfigWorkflow(BaseModel): ... data: ... available: ... - foo: + ... type: "file" + ... src: "foo.txt" ... generated: ... - bar: + ... type: "file" + ... src: some_task_output ... ''' ... ) >>> wf = pydantic_yaml.parse_yaml_raw_as(ConfigWorkflow, config) @@ -633,4 +713,3 @@ def load_workflow_config(workflow_config: str) -> CanonicalWorkflow: rootdir = config_path.resolve().parent return canonicalize_workflow(config_workflow=parsed_workflow, rootdir=rootdir) - # return parsed_workflow diff --git a/tests/unit_tests/core/test_workflow.py b/tests/unit_tests/core/test_workflow.py index 188bb85e..40b45d06 100644 --- a/tests/unit_tests/core/test_workflow.py +++ b/tests/unit_tests/core/test_workflow.py @@ -9,11 +9,11 @@ def test_minimal_workflow(): minimal_config = models.CanonicalWorkflow( name="minimal", rootdir=pathlib.Path("minimal"), - cycles=[models.ConfigCycle(some_cycle={"tasks": []})], - tasks=[models.ConfigShellTask(some_task={"plugin": "shell"})], + cycles=[models.ConfigCycle(name="some_cycle", tasks=[])], + tasks=[models.ConfigShellTask(name="some_task", plugin="shell")], data=models.ConfigData( - available=[models.ConfigAvailableData(foo={})], - generated=[models.ConfigGeneratedData(bar={})], + available=[models.ConfigAvailableData(name="foo", type=models.DataType.FILE, src="foo.txt")], + generated=[models.ConfigGeneratedData(name="bar", type=models.DataType.DIR, src="bar")], ), parameters={}, ) diff --git a/tests/unit_tests/parsing/test_yaml_data_models.py b/tests/unit_tests/parsing/test_yaml_data_models.py index 6d827eba..7bc695a1 100644 --- a/tests/unit_tests/parsing/test_yaml_data_models.py +++ b/tests/unit_tests/parsing/test_yaml_data_models.py @@ -1,17 +1,36 @@ import pathlib import textwrap +import pydantic +import pytest + from sirocco.parsing import _yaml_data_models as models +@pytest.mark.parametrize("data_type", ["file", "dir"]) +def test_base_data(data_type): + testee = models.ConfigBaseData(name="name", type=data_type, src="foo.txt", format=None) + + assert testee.type == data_type + + +@pytest.mark.parametrize("data_type", [None, "invalid", 1.42]) +def test_base_data_invalid_type(data_type): + with pytest.raises(pydantic.ValidationError): + _ = models.ConfigBaseData(name="name", src="foo", format="nml") + + with pytest.raises(pydantic.ValidationError): + _ = models.ConfigBaseData(name="name", type=data_type, src="foo", format="nml") + + def test_workflow_canonicalization(): config = models.ConfigWorkflow( name="testee", - cycles=[models.ConfigCycle(minimal={"tasks": [models.ConfigCycleTask(a={})]})], - tasks=[{"some_task": {"plugin": "shell"}}], + cycles=[models.ConfigCycle(name="minimal", tasks=[models.ConfigCycleTask(name="a")])], + tasks=[models.ConfigShellTask(name="some_task")], data=models.ConfigData( - available=[models.ConfigAvailableData(foo={})], - generated=[models.ConfigGeneratedData(bar={})], + available=[models.ConfigAvailableData(name="foo", type=models.DataType.FILE, src="foo.txt")], + generated=[models.ConfigGeneratedData(name="bar", type=models.DataType.DIR, src="bar")], ), ) @@ -34,8 +53,12 @@ def test_load_workflow_config(tmp_path): data: available: - c: + type: "file" + src: "c.txt" generated: - d: + type: "dir" + src: "d" """ ) minimal = tmp_path / "minimal.yml"