diff --git a/tests/test_err_msg.py b/tests/test_err_msg.py new file mode 100644 index 0000000..57115b6 --- /dev/null +++ b/tests/test_err_msg.py @@ -0,0 +1,23 @@ +import dataclasses + +import pytest + +import znflow + + +@dataclasses.dataclass +class ComputeMean(znflow.Node): + x: float + y: float + + results: float = None + + def run(self): + self.results = (self.x + self.y) / 2 + + +def test_attribute_access(): + with znflow.DiGraph(): + n1 = ComputeMean(2, 8) + with pytest.raises(znflow.exceptions.ConnectionAttributeError): + n1.x.data diff --git a/znflow/__init__.py b/znflow/__init__.py index 533d87e..5d3c8f8 100644 --- a/znflow/__init__.py +++ b/znflow/__init__.py @@ -4,6 +4,7 @@ import logging import sys +from znflow import exceptions from znflow.base import ( CombinedConnections, Connection, @@ -31,6 +32,7 @@ "Property", "CombinedConnections", "combine", + "exceptions", ] with contextlib.suppress(ImportError): diff --git a/znflow/base.py b/znflow/base.py index cd5341f..6631965 100644 --- a/znflow/base.py +++ b/znflow/base.py @@ -3,8 +3,11 @@ import contextlib import dataclasses import typing +from typing import Any from uuid import UUID +from znflow import exceptions + @contextlib.contextmanager def disable_graph(*args, **kwargs): @@ -183,6 +186,14 @@ def result(self): result = self.instance return result[self.item] if self.item else result + def __getattribute__(self, __name: str) -> Any: + try: + return super().__getattribute__(__name) + except AttributeError as e: + raise exceptions.ConnectionAttributeError( + "Connection does not support further attributes to its result." + ) from e + @dataclasses.dataclass(frozen=True) class CombinedConnections: diff --git a/znflow/exceptions.py b/znflow/exceptions.py new file mode 100644 index 0000000..227fc47 --- /dev/null +++ b/znflow/exceptions.py @@ -0,0 +1,5 @@ +"""ZnFlow exceptions.""" + + +class ConnectionAttributeError(AttributeError): + """Raised when a connection attribute is not found."""