diff --git a/pyproject.toml b/pyproject.toml index c4c2436f..006a92f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "numpy", "isoduration", "pydantic", - "pydantic-yaml", + "ruamel.yaml", "aiida-core>=2.5", "aiida-workgraph==0.4.10", "termcolor", diff --git a/src/sirocco/parsing/_yaml_data_models.py b/src/sirocco/parsing/_yaml_data_models.py index 48568a23..4a71b404 100644 --- a/src/sirocco/parsing/_yaml_data_models.py +++ b/src/sirocco/parsing/_yaml_data_models.py @@ -6,6 +6,7 @@ import typing from dataclasses import dataclass, field from datetime import datetime +from io import StringIO from pathlib import Path from typing import Annotated, Any, ClassVar, Literal @@ -18,9 +19,11 @@ Discriminator, Field, Tag, + TypeAdapter, field_validator, model_validator, ) +from ruamel.yaml import YAML from sirocco.parsing._utils import TimeUtils @@ -41,8 +44,8 @@ class _NamedBaseModel(BaseModel): >>> _NamedBaseModel(foo={}) _NamedBaseModel(name='foo') - >>> import pydantic_yaml, textwrap - >>> pydantic_yaml.parse_yaml_raw_as( + >>> import textwrap + >>> validate_yaml_content( ... _NamedBaseModel, ... textwrap.dedent(''' ... foo: @@ -50,7 +53,7 @@ class _NamedBaseModel(BaseModel): ... ) _NamedBaseModel(name='foo') - >>> pydantic_yaml.parse_yaml_raw_as( + >>> validate_yaml_content( ... _NamedBaseModel, ... textwrap.dedent(''' ... name: foo @@ -479,7 +482,6 @@ class ConfigBaseData(_NamedBaseModel, ConfigBaseDataSpecs): yaml snippet: >>> import textwrap - >>> import pydantic_yaml >>> snippet = textwrap.dedent( ... ''' ... foo: @@ -487,7 +489,7 @@ class ConfigBaseData(_NamedBaseModel, ConfigBaseDataSpecs): ... src: "foo.txt" ... ''' ... 
OBJECT_T = typing.TypeVar("OBJECT_T")


def validate_yaml_content(cls: type[OBJECT_T], content: str) -> OBJECT_T:
    """Load a YAML string and validate the result against ``cls`` with pydantic.

    The YAML text is first deserialized into plain python objects by
    ruamel.yaml, configured with ``typ="safe"`` and ``pure=True``. The
    resulting object is then handed to a pydantic ``TypeAdapter`` built for
    ``cls``.

    Args:
        cls (type[OBJECT_T]): The target type the loaded YAML data is
            validated against. It has to be usable with pydantic validation.
        content (str): The YAML document as a string.

    Returns:
        OBJECT_T: The instance of ``cls`` produced by validating the data
            loaded from ``content``.

    Raises:
        pydantic.ValidationError: If the loaded data does not validate
            against ``cls``.
        ruamel.yaml.YAMLError: If ``content`` cannot be parsed as YAML.
    """
    # The adapter is created before the YAML is parsed, so a type that
    # pydantic cannot handle is reported ahead of any YAML syntax problem.
    adapter = TypeAdapter(cls)
    yaml_loader = YAML(typ="safe", pure=True)
    raw_data = yaml_loader.load(StringIO(content))
    return adapter.validate_python(raw_data)