Skip to content

Commit

Permalink
from_yaml to from_config_file
Browse files Browse the repository at this point in the history
  • Loading branch information
agoscinski committed Jan 30, 2025
1 parent 9a5f3c4 commit a719ac9
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
6 changes: 3 additions & 3 deletions src/sirocco/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@ def cycle_dates(cycle_config: ConfigCycle) -> Iterator[datetime]:
yield date

@classmethod
def from_yaml(cls: type[Self], config_file_path: str) -> Self:
def from_config_file(cls: type[Self], config_path: str) -> Self:
"""
Loads a python representation of a workflow config file.
:param config_file_path: the string to the config yaml file containing the workflow definition
:param config_path: the string to the config yaml file containing the workflow definition
"""
return cls.from_config_workflow(ConfigWorkflow.from_config_file(config_filename))
return cls.from_config_workflow(ConfigWorkflow.from_config_file(config_path))

@classmethod
def from_config_workflow(cls: type[Self], config_workflow: ConfigWorkflow) -> Workflow:
Expand Down
12 changes: 6 additions & 6 deletions src/sirocco/parsing/_yaml_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,24 +681,24 @@ def check_parameters(self) -> ConfigWorkflow:
return self

@classmethod
def from_config_file(cls, config_file_path: str) -> 'ConfigWorkflow':
def from_config_file(cls, config_path: str) -> ConfigWorkflow:
"""Creates a ConfigWorkflow instance from a config file, a yaml with the workflow definition.
Args:
config_file_path (str): The path of the config file to load from
config_path (str): The path of the config file to load from
Returns:
OBJECT_T: An instance of the specified class type with data parsed and
validated from the YAML content.
"""
config_path = Path(config_file_path)
content = config_path.read_text()
config_path_ = Path(config_path)
content = config_path_.read_text()
reader = YAML(typ="safe", pure=True)
object_ = reader.load(StringIO(content))
# If name was not specified, then we use filename without file extension
if "name" not in object_:
object_["name"] = config_path.stem
object_["rootdir"] = config_path.resolve().parent
object_["name"] = config_path_.stem
object_["rootdir"] = config_path_.resolve().parent
adapter = TypeAdapter(cls)
return adapter.validate_python(object_)

Expand Down
4 changes: 2 additions & 2 deletions src/sirocco/vizgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,5 @@ def from_core_workflow(cls, workflow: Workflow):
return cls(workflow.name, workflow.cycles, workflow.data)

@classmethod
def from_yaml(cls, config_path: str):
return cls.from_core_workflow(Workflow.from_yaml(config_path))
def from_config_file(cls, config_path: str):
return cls.from_core_workflow(Workflow.from_config_file(config_path))
12 changes: 6 additions & 6 deletions tests/test_wc_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def config_paths(request):

def test_parse_config_file(config_paths, pprinter):
reference_str = config_paths["txt"].read_text()
test_str = pprinter.format(Workflow.from_yaml(config_paths["yml"]))
test_str = pprinter.format(Workflow.from_config_file(config_paths["yml"]))
if test_str != reference_str:
new_path = Path(config_paths["txt"]).with_suffix(".new.txt")
new_path.write_text(test_str)
Expand All @@ -61,11 +61,11 @@ def test_parse_config_file(config_paths, pprinter):

@pytest.mark.skip(reason="don't run it each time, uncomment to regenerate serilaized data")
def test_serialize_workflow(config_paths, pprinter):
config_paths["txt"].write_text(pprinter.format(Workflow.from_yaml(config_paths["yml"])))
config_paths["txt"].write_text(pprinter.format(Workflow.from_config_file(config_paths["yml"])))


def test_vizgraph(config_paths):
VizGraph.from_yaml(config_paths["yml"]).draw(file_path=config_paths["svg"])
VizGraph.from_config_file(config_paths["yml"]).draw(file_path=config_paths["svg"])


# configs that are tested for running workgraph
Expand All @@ -85,7 +85,7 @@ def test_run_workgraph(config_path, aiida_computer):
# some configs reference computer "localhost" which we need to create beforehand
aiida_computer("localhost").store()

core_workflow = Workflow.from_yaml(config_path)
core_workflow = Workflow.from_config_file(config_path)
aiida_workflow = AiidaWorkGraph(core_workflow)
out = aiida_workflow.run()
assert out.get("execution_count", None).value == 1
Expand All @@ -98,7 +98,7 @@ def test_run_workgraph(config_path, aiida_computer):
)
def test_nml_mod(config_paths, tmp_path):
nml_refdir = config_paths["txt"].parent / "ICON_namelists"
wf = Workflow.from_yaml(config_paths["yml"])
wf = Workflow.from_config_file(config_paths["yml"])
# Create core mamelists
for task in wf.tasks:
if isinstance(task, IconTask):
Expand All @@ -121,7 +121,7 @@ def test_nml_mod(config_paths, tmp_path):
)
def test_serialize_nml(config_paths):
nml_refdir = config_paths["txt"].parent / "ICON_namelists"
wf = Workflow.from_yaml(config_paths["yml"])
wf = Workflow.from_config_file(config_paths["yml"])
for task in wf.tasks:
if isinstance(task, IconTask):
task.create_workflow_namelists(folder=nml_refdir)

0 comments on commit a719ac9

Please sign in to comment.