Skip to content

Commit

Permalink
Improve ConfigData & related (#101)
Browse files Browse the repository at this point in the history
* [wip] first pass at ConfigData and related classes

* replace canonical data models with improved named base

* update doctests and unit tests
  • Loading branch information
DropD authored Jan 27, 2025
1 parent 7b0cfa3 commit 6051f9c
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 49 deletions.
7 changes: 6 additions & 1 deletion src/sirocco/core/graph_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from datetime import datetime
from pathlib import Path

from sirocco.parsing._yaml_data_models import ConfigBaseData, ConfigCycleTask, ConfigTask, TargetNodesBaseModel
from sirocco.parsing._yaml_data_models import (
ConfigBaseData,
ConfigCycleTask,
ConfigTask,
TargetNodesBaseModel,
)


@dataclass
Expand Down
159 changes: 119 additions & 40 deletions src/sirocco/parsing/_yaml_data_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
import itertools
import time
import typing
Expand All @@ -25,46 +26,68 @@


class _NamedBaseModel(BaseModel):
"""Base class for all classes with a key that specifies their name.
For example:
"""
Base model for reading names from yaml keys *or* keyword args to the constructor.
.. yaml
Reading from key-value pairs in yaml is also supported in order to enable
the standard constructor usage from Python, as demonstrated in the below
examples. On its own it is not considered desirable.
- property_name:
property: true
Examples:
When parsing with this as parent class it is converted to
`{"name": "property_name", "property": True}`.
>>> _NamedBaseModel(name="foo")
_NamedBaseModel(name='foo')
>>> _NamedBaseModel(foo={})
_NamedBaseModel(name='foo')
>>> import pydantic_yaml, textwrap
>>> pydantic_yaml.parse_yaml_raw_as(
... _NamedBaseModel,
... textwrap.dedent('''
... foo:
... '''),
... )
_NamedBaseModel(name='foo')
>>> pydantic_yaml.parse_yaml_raw_as(
... _NamedBaseModel,
... textwrap.dedent('''
... name: foo
... '''),
... )
_NamedBaseModel(name='foo')
"""

name: str

def __init__(self, /, **data):
super().__init__(**self.merge_name_and_specs(data))

@staticmethod
def merge_name_and_specs(data: dict) -> dict:
"""
Converts dict of form
`{my_name: {'spec_0': ..., ..., 'spec_n': ...}`
to
`{'name': my_name, 'spec_0': ..., ..., 'spec_n': ...}`
@model_validator(mode="before")
@classmethod
def reformat_named_object(cls, data: Any) -> Any:
return cls.extract_merge_name(data)

by copy.
"""
name_and_spec = {}
if len(data) != 1:
msg = f"Expected dict with one element of the form {{'name': specification}} but got {data}."
raise ValueError(msg)
name_and_spec["name"] = next(iter(data.keys()))
# if no specification specified e.g. "- my_name:"
if (spec := next(iter(data.values()))) is not None:
name_and_spec.update(spec)
return name_and_spec
@classmethod
def extract_merge_name(cls, data: Any) -> Any:
    """
    Normalise the ``{name: spec}`` shorthand into a flat mapping with a ``name`` key.

    Anything that is not a dict, or a dict that does not hold exactly one
    entry, is handed back untouched. For a single-entry dict:

    - ``{"name": "<str>"}`` is returned as is,
    - ``{key: {...}}`` (inner dict without a ``"name"`` key) becomes the
      inner dict plus ``{"name": key}``,
    - ``{key: None}`` becomes ``{"name": key}``.

    Raises:
        TypeError: if the single key is not a ``str``, or if the value is
            none of the shapes listed above.
    """
    # Only a one-entry mapping can be the shorthand form.
    if not isinstance(data, dict) or len(data) != 1:
        return data

    [(key, value)] = data.items()

    if not isinstance(key, str):
        msg = f"{cls.__name__} requires name to be a str (got {key})."
        raise TypeError(msg)

    # Already in the explicit form, nothing to merge.
    if isinstance(value, str) and key == "name":
        return data
    # Named specification: pull the key in as the name.
    if isinstance(value, dict) and "name" not in value:
        return value | {"name": key}
    # Bare name without any specification.
    if value is None:
        return {"name": key}

    msg = f"{cls.__name__} may only be used for named objects, not values (got {data})."
    raise TypeError(msg)


class _WhenBaseModel(BaseModel):
Expand Down Expand Up @@ -434,23 +457,47 @@ def check_nml(cls, nml_list: list[Any]) -> ConfigNamelist:
return namelists


class DataType(enum.StrEnum):
FILE = enum.auto()
DIR = enum.auto()


@dataclass
class ConfigBaseDataSpecs:
type: str | None = None
src: str | None = None
type: DataType
src: str
format: str | None = None
computer: str | None = None


class ConfigBaseData(_NamedBaseModel, ConfigBaseDataSpecs):
"""
To create an instance of a data defined in a workflow file.
Examples:
yaml snippet:
>>> import textwrap
>>> import pydantic_yaml
>>> snippet = textwrap.dedent(
... '''
... foo:
... type: "file"
... src: "foo.txt"
... '''
... )
>>> pydantic_yaml.parse_yaml_raw_as(ConfigBaseData, snippet)
ConfigBaseData(type=<DataType.FILE: 'file'>, src='foo.txt', format=None, computer=None, name='foo', parameters=[])
from python:
>>> ConfigBaseData(name="foo", type=DataType.FILE, src="foo.txt")
ConfigBaseData(type=<DataType.FILE: 'file'>, src='foo.txt', format=None, computer=None, name='foo', parameters=[])
"""

parameters: list[str] = []
type: str | None = None
src: str | None = None
format: str | None = None

@field_validator("type")
@classmethod
Expand Down Expand Up @@ -478,7 +525,36 @@ def invalid_field(cls, value: str | None) -> str | None:


class ConfigData(BaseModel):
"""To create the container of available and generated data"""
"""
To create the container of available and generated data
Example:
yaml snippet:
>>> import textwrap
>>> import pydantic_yaml
>>> snippet = textwrap.dedent(
... '''
... available:
... - foo:
... type: "file"
... src: "foo.txt"
... generated:
... - bar:
... type: "file"
... src: "bar.txt"
... '''
... )
>>> data = pydantic_yaml.parse_yaml_raw_as(ConfigData, snippet)
>>> assert data.available[0].name == "foo"
>>> assert data.generated[0].name == "bar"
from python:
>>> ConfigData()
ConfigData(available=[], generated=[])
"""

available: list[ConfigAvailableData] = []
generated: list[ConfigGeneratedData] = []
Expand All @@ -489,7 +565,7 @@ def get_plugin_from_named_base_model(
) -> str:
if isinstance(data, (ConfigRootTask, ConfigShellTask, ConfigIconTask)):
return data.plugin
name_and_specs = _NamedBaseModel.merge_name_and_specs(data)
name_and_specs = ConfigBaseTask.extract_merge_name(data)
if name_and_specs.get("name", None) == "ROOT":
return ConfigRootTask.plugin
plugin = name_and_specs.get("plugin", None)
Expand Down Expand Up @@ -529,8 +605,12 @@ class ConfigWorkflow(BaseModel):
... data:
... available:
... - foo:
... type: "file"
... src: "foo.txt"
... generated:
... - bar:
... type: "file"
... src: some_task_output
... '''
... )
>>> wf = pydantic_yaml.parse_yaml_raw_as(ConfigWorkflow, config)
Expand Down Expand Up @@ -633,4 +713,3 @@ def load_workflow_config(workflow_config: str) -> CanonicalWorkflow:
rootdir = config_path.resolve().parent

return canonicalize_workflow(config_workflow=parsed_workflow, rootdir=rootdir)
# return parsed_workflow
8 changes: 4 additions & 4 deletions tests/unit_tests/core/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ def test_minimal_workflow():
minimal_config = models.CanonicalWorkflow(
name="minimal",
rootdir=pathlib.Path("minimal"),
cycles=[models.ConfigCycle(some_cycle={"tasks": []})],
tasks=[models.ConfigShellTask(some_task={"plugin": "shell"})],
cycles=[models.ConfigCycle(name="some_cycle", tasks=[])],
tasks=[models.ConfigShellTask(name="some_task", plugin="shell")],
data=models.ConfigData(
available=[models.ConfigAvailableData(foo={})],
generated=[models.ConfigGeneratedData(bar={})],
available=[models.ConfigAvailableData(name="foo", type=models.DataType.FILE, src="foo.txt")],
generated=[models.ConfigGeneratedData(name="bar", type=models.DataType.DIR, src="bar")],
),
parameters={},
)
Expand Down
31 changes: 27 additions & 4 deletions tests/unit_tests/parsing/test_yaml_data_models.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,36 @@
import pathlib
import textwrap

import pydantic
import pytest

from sirocco.parsing import _yaml_data_models as models


@pytest.mark.parametrize("data_type", ["file", "dir"])
def test_base_data(data_type):
testee = models.ConfigBaseData(name="name", type=data_type, src="foo.txt", format=None)

assert testee.type == data_type


@pytest.mark.parametrize("data_type", [None, "invalid", 1.42])
def test_base_data_invalid_type(data_type):
with pytest.raises(pydantic.ValidationError):
_ = models.ConfigBaseData(name="name", src="foo", format="nml")

with pytest.raises(pydantic.ValidationError):
_ = models.ConfigBaseData(name="name", type=data_type, src="foo", format="nml")


def test_workflow_canonicalization():
config = models.ConfigWorkflow(
name="testee",
cycles=[models.ConfigCycle(minimal={"tasks": [models.ConfigCycleTask(a={})]})],
tasks=[{"some_task": {"plugin": "shell"}}],
cycles=[models.ConfigCycle(name="minimal", tasks=[models.ConfigCycleTask(name="a")])],
tasks=[models.ConfigShellTask(name="some_task")],
data=models.ConfigData(
available=[models.ConfigAvailableData(foo={})],
generated=[models.ConfigGeneratedData(bar={})],
available=[models.ConfigAvailableData(name="foo", type=models.DataType.FILE, src="foo.txt")],
generated=[models.ConfigGeneratedData(name="bar", type=models.DataType.DIR, src="bar")],
),
)

Expand All @@ -34,8 +53,12 @@ def test_load_workflow_config(tmp_path):
data:
available:
- c:
type: "file"
src: "c.txt"
generated:
- d:
type: "dir"
src: "d"
"""
)
minimal = tmp_path / "minimal.yml"
Expand Down

0 comments on commit 6051f9c

Please sign in to comment.