Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mypy storage #553

Merged
merged 5 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions pyiron_workflow/mixin/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ def _none_to_dict(inp: dict | None) -> dict:
**run_kwargs,
)

def _before_run(self, /, check_readiness: bool, **kwargs) -> tuple[bool, Any]:
def _before_run(
self, /, check_readiness: bool, *args, **kwargs
) -> tuple[bool, Any]:
"""
Things to do _before_ running.

Expand Down Expand Up @@ -194,6 +196,7 @@ def _run(
run_exception_kwargs: dict,
run_finally_kwargs: dict,
finish_run_kwargs: dict,
*args,
**kwargs,
) -> Any | tuple | Future:
"""
Expand Down Expand Up @@ -254,15 +257,15 @@ def _run(
)
return self.future

def _run_exception(self, /, **kwargs):
def _run_exception(self, /, *args, **kwargs):
"""
What to do if an exception is encountered inside :meth:`_run` or
:meth:`_finish_run`.
"""
self.running = False
self.failed = True

def _run_finally(self, /):
def _run_finally(self, /, *args, **kwargs):
"""
What to do after :meth:`_finish_run` (whether an exception is encountered or
not), or in :meth:`_run` after an exception is encountered.
Expand Down
4 changes: 0 additions & 4 deletions pyiron_workflow/nodes/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@ class FromManyInputs(Transformer, ABC):
# Inputs convert to `run_args` as a value dictionary
# This must be commensurate with the internal expectations of _on_run

@abstractmethod
def _on_run(self, **inputs_to_value_dict) -> Any:
"""Must take inputs kwargs"""

@property
def _run_args(self) -> tuple[tuple, dict]:
return (), self.inputs.to_value_dict()
Expand Down
34 changes: 22 additions & 12 deletions pyiron_workflow/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class StorageInterface(ABC):
"""

@abstractmethod
def _save(self, node: Node, filename: Path, /, **kwargs):
def _save(self, node: Node, filename: Path, /, *args, **kwargs):
"""
Save a node to file.

Expand All @@ -48,7 +48,7 @@ def _save(self, node: Node, filename: Path, /, **kwargs):
"""

@abstractmethod
def _load(self, filename: Path, /, **kwargs) -> Node:
def _load(self, filename: Path, /, *args, **kwargs) -> Node:
"""
Instantiate a node from file.

Expand All @@ -61,7 +61,7 @@ def _load(self, filename: Path, /, **kwargs) -> Node:
"""

@abstractmethod
def _has_saved_content(self, filename: Path, /, **kwargs) -> bool:
def _has_saved_content(self, filename: Path, /, *args, **kwargs) -> bool:
"""
Check for a save file matching this storage interface.

Expand All @@ -74,7 +74,7 @@ def _has_saved_content(self, filename: Path, /, **kwargs) -> bool:
"""

@abstractmethod
def _delete(self, filename: Path, /, **kwargs):
def _delete(self, filename: Path, /, *args, **kwargs):
"""
Remove an existing save-file for this backend.

Expand Down Expand Up @@ -132,7 +132,7 @@ def has_saved_content(
node: Node | None = None,
filename: str | Path | None = None,
**kwargs,
):
) -> bool:
"""
Check if a file has contents related to a node.

Expand Down Expand Up @@ -168,7 +168,9 @@ def delete(
if filename.parent.exists() and not any(filename.parent.iterdir()):
filename.parent.rmdir()

def _parse_filename(self, node: Node | None, filename: str | Path | None = None):
def _parse_filename(
self, node: Node | None, filename: str | Path | None = None
) -> Path:
"""
Make sure the node xor filename was provided, and if it's the node, convert it
into a canonical filename by exploiting the node's semantic path.
Expand All @@ -195,6 +197,11 @@ def _parse_filename(self, node: Node | None, filename: str | Path | None = None)
f"Both the node ({node.full_label}) and filename ({filename}) were "
f"specified for loading -- please only specify one or the other."
)
else:
raise AssertionError(
"This is an unreachable state -- we have covered all four cases of the "
"boolean `is (not) None` square."
)


class PickleStorage(StorageInterface):
Expand All @@ -204,11 +211,11 @@ class PickleStorage(StorageInterface):
def __init__(self, cloudpickle_fallback: bool = True):
self.cloudpickle_fallback = cloudpickle_fallback

def _fallback(self, cpf: bool | None):
def _fallback(self, cpf: bool | None) -> bool:
return self.cloudpickle_fallback if cpf is None else cpf

def _save(
self, node: Node, filename: Path, cloudpickle_fallback: bool | None = None
self, node: Node, filename: Path, /, cloudpickle_fallback: bool | None = None
):
if not self._fallback(cloudpickle_fallback) and not node.import_ready:
raise TypeNotFoundError(
Expand Down Expand Up @@ -236,19 +243,22 @@ def _save(
if e is not None:
raise e

def _load(self, filename: Path, cloudpickle_fallback: bool | None = None) -> Node:
def _load(
self, filename: Path, /, cloudpickle_fallback: bool | None = None
) -> Node:
attacks = [(self._PICKLE, pickle.load)]
if self._fallback(cloudpickle_fallback):
attacks += [(self._CLOUDPICKLE, cloudpickle.load)]

for suffix, load_method in attacks:
p = filename.with_suffix(suffix)
if p.exists():
if p.is_file():
with open(p, "rb") as filehandle:
inst = load_method(filehandle)
return inst
        raise FileNotFoundError(f"Could not load {filename}, no such file found.")

def _delete(self, filename: Path, cloudpickle_fallback: bool | None = None):
def _delete(self, filename: Path, /, cloudpickle_fallback: bool | None = None):
suffixes = (
[self._PICKLE, self._CLOUDPICKLE]
if self._fallback(cloudpickle_fallback)
Expand All @@ -258,7 +268,7 @@ def _delete(self, filename: Path, cloudpickle_fallback: bool | None = None):
filename.with_suffix(suffix).unlink(missing_ok=True)

def _has_saved_content(
self, filename: Path, cloudpickle_fallback: bool | None = None
self, filename: Path, /, cloudpickle_fallback: bool | None = None
) -> bool:
suffixes = (
[self._PICKLE, self._CLOUDPICKLE]
Expand Down
12 changes: 12 additions & 0 deletions pyiron_workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ class ParentMostError(TypeError):
"""


class NoArgsError(TypeError):
"""
To be raised when *args can't be processed but are received
"""


class Workflow(Composite):
"""
Workflows are a dynamic composite node -- i.e. they hold and run a collection of
Expand Down Expand Up @@ -361,12 +367,18 @@ def _before_run(

def run(
self,
*args,
check_readiness: bool = True,
**kwargs,
):
# Note: Workflows may have neither parents nor siblings, so we don't need to
# worry about running their data trees first, fetching their input, nor firing
# their `ran` signal, hence the change in signature from Node.run
if len(args) > 0:
raise NoArgsError(
f"{self.__class__} does not know how to process *args on run, but "
f"received {args}"
)

return super().run(
run_data_tree=False,
Expand Down
8 changes: 7 additions & 1 deletion tests/unit/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pyiron_workflow._tests import ensure_tests_in_python_path
from pyiron_workflow.channels import NOT_DATA
from pyiron_workflow.storage import TypeNotFoundError, available_backends
from pyiron_workflow.workflow import ParentMostError, Workflow
from pyiron_workflow.workflow import NoArgsError, ParentMostError, Workflow

ensure_tests_in_python_path()

Expand Down Expand Up @@ -258,6 +258,12 @@ def sum_(a, b):
return a + b

wf.sum = sum_(wf.a, wf.b)
with self.assertRaises(
NoArgsError,
msg="Workflows don't know what to do with raw args, since their input "
"has no intrinsic order",
):
wf.run(1, 2)
wf.run()
self.assertEqual(
wf.a.outputs.y.value + wf.b.outputs.y.value,
Expand Down
Loading