diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 9db0972e..68c0ec0d 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -15,10 +15,3 @@ jobs: - uses: astral-sh/ruff-action@v1 with: args: check --select I --fix --diff - ruff-format: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/ruff-action@v1 - with: - args: format --diff diff --git a/pyiron_workflow/__init__.py b/pyiron_workflow/__init__.py index bf693bf9..c6291799 100644 --- a/pyiron_workflow/__init__.py +++ b/pyiron_workflow/__init__.py @@ -27,9 +27,6 @@ - GUI on top for code-lite/code-free visual scripting """ -# deactivate: imported but unused -# flake8: noqa: F401 - from ._version import get_versions __version__ = get_versions()["version"] @@ -37,6 +34,8 @@ # API # User entry point +from pyiron_workflow.workflow import Workflow # ruff: isort: skip + # Node developer entry points from pyiron_workflow.channels import NOT_DATA from pyiron_workflow.find import ( @@ -62,4 +61,3 @@ TypeNotFoundError, available_backends, ) -from pyiron_workflow.workflow import Workflow diff --git a/pyiron_workflow/create.py b/pyiron_workflow/create.py index 82bf5695..cd5c416f 100644 --- a/pyiron_workflow/create.py +++ b/pyiron_workflow/create.py @@ -4,9 +4,7 @@ from __future__ import annotations -from abc import ABC from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor -from functools import lru_cache from executorlib import Executor as ExecutorlibExecutor from pyiron_snippets.dotdict import DotDict @@ -32,7 +30,6 @@ class Creator(metaclass=Singleton): """ def __init__(self): - # Standard lib self.ProcessPoolExecutor = ProcessPoolExecutor self.ThreadPoolExecutor = ThreadPoolExecutor @@ -44,35 +41,30 @@ def __init__(self): self.function_node = function_node @property - @lru_cache(maxsize=1) def standard(self): from pyiron_workflow.nodes import standard return standard @property - @lru_cache(maxsize=1) def for_node(self): from pyiron_workflow.nodes.for_loop import for_node return for_node @property - @lru_cache(maxsize=1) def macro_node(self): from pyiron_workflow.nodes.macro import macro_node return macro_node @property - @lru_cache(maxsize=1) def Workflow(self): from pyiron_workflow.workflow import Workflow return Workflow @property - @lru_cache(maxsize=1) def meta(self): from pyiron_workflow.nodes.transform import inputs_to_list, list_to_outputs @@ -84,7 +76,6 @@ def meta(self): ) @property - @lru_cache(maxsize=1) def transformer(self): from pyiron_workflow.nodes.transform import ( dataclass_node, @@ -117,21 +108,19 @@ class Wrappers(metaclass=Singleton): as_function_node = staticmethod(as_function_node) @property - @lru_cache(maxsize=1) def as_macro_node(self): from pyiron_workflow.nodes.macro import as_macro_node return as_macro_node @property - @lru_cache(maxsize=1) def as_dataclass_node(self): from pyiron_workflow.nodes.transform import as_dataclass_node return as_dataclass_node -class HasCreator(ABC): +class HasCreator: """ A mixin class for creator (including both class-like and decorator). """ diff --git a/pyiron_workflow/mixin/has_interface_mixins.py b/pyiron_workflow/mixin/has_interface_mixins.py index 59691503..2828ce7e 100644 --- a/pyiron_workflow/mixin/has_interface_mixins.py +++ b/pyiron_workflow/mixin/has_interface_mixins.py @@ -17,7 +17,7 @@ from pyiron_workflow.channels import Channel -class UsesState(ABC): +class UsesState: """ A mixin for any class using :meth:`__getstate__` or :meth:`__setstate__`. 
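Aside on the hunks above: unlike the rest of this patch, dropping `@lru_cache(maxsize=1)` from the `Creator`/`Wrappers` properties is more than formatting. Ruff's B019 rule flags `lru_cache` on methods and property getters because the cache is module-level and keyed on `self`, so it holds a strong reference that can keep instances alive past their last use; here the getters are cheap lazy imports anyway, so the cache is simply removed. A minimal sketch of the pattern and the usual alternative (`functools.cached_property`) — `Holder` is a hypothetical stand-in, not a class from this repo, and the PR itself does not adopt `cached_property`:

```python
from functools import cached_property, lru_cache


class Holder:
    """Hypothetical stand-in; not a pyiron_workflow class."""

    @property
    @lru_cache(maxsize=1)  # B019: the cache lives at module level and is
    def leaky(self):       # keyed on `self`, so it pins a strong reference
        return object()    # that can keep instances from being collected

    @cached_property       # common alternative: the value is stored on the
    def contained(self):   # instance itself and is freed along with it
        return object()


h = Holder()
assert h.leaky is h.leaky          # cached across accesses via the global cache
assert h.contained is h.contained  # also cached, but owned by `h`
```

For the singleton `Creator` the practical leak was negligible; removing the caching trades a redundant `sys.modules` lookup per access for code that the linter accepts.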
diff --git a/pyiron_workflow/mixin/run.py b/pyiron_workflow/mixin/run.py index 9ec01095..b704abc7 100644 --- a/pyiron_workflow/mixin/run.py +++ b/pyiron_workflow/mixin/run.py @@ -308,7 +308,7 @@ def _readiness_error_message(self) -> str: @staticmethod def _parse_executor( - executor: StdLibExecutor | (callable[..., StdLibExecutor], tuple, dict) + executor: StdLibExecutor | (callable[..., StdLibExecutor], tuple, dict), ) -> StdLibExecutor: """ If you've already got an executor, you're done. But if you get callable and diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 866f173b..3b86a5e4 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -518,7 +518,6 @@ def _before_run( self.inputs.fetch() if self.use_cache and self.cache_hit: # Read and use cache - if self.parent is None and emit_ran_signal: self.emit() elif self.parent is not None: diff --git a/pyiron_workflow/nodes/for_loop.py b/pyiron_workflow/nodes/for_loop.py index d358b259..b2b61979 100644 --- a/pyiron_workflow/nodes/for_loop.py +++ b/pyiron_workflow/nodes/for_loop.py @@ -306,7 +306,6 @@ def _collect_output_as_dataframe(self, iter_maps): self.dataframe.outputs.df.value_receiver = self.outputs.df for n, channel_map in enumerate(iter_maps): - row_collector = self._build_row_collector_node(n) for label, i in channel_map.items(): row_collector.inputs[label] = self.children[label][i] diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index 7d1dcc0f..97befbbb 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -197,7 +197,7 @@ def _build_inputs_preview(cls) -> dict[str, tuple[Any | None, Any | NOT_DATA]]: @staticmethod def hash_specification( - input_specification: list[str] | dict[str, tuple[Any | None, Any | NOT_DATA]] + input_specification: list[str] | dict[str, tuple[Any | None, Any | NOT_DATA]], ): """For generating unique subclass names.""" diff --git a/pyiron_workflow/output_parser.py b/pyiron_workflow/output_parser.py index cdea18f3..c436c65a 100644 --- a/pyiron_workflow/output_parser.py +++ b/pyiron_workflow/output_parser.py @@ -5,7 +5,6 @@ import ast import inspect import re -from functools import lru_cache from textwrap import dedent @@ -57,7 +56,6 @@ def node_return(self): return None @property - @lru_cache(maxsize=1) def source(self): return self.dedented_source_string.split("\n")[:-1] diff --git a/pyiron_workflow/storage.py b/pyiron_workflow/storage.py index a3b14b37..679f8151 100644 --- a/pyiron_workflow/storage.py +++ b/pyiron_workflow/storage.py @@ -198,7 +198,6 @@ def _parse_filename(self, node: Node | None, filename: str | Path | None = None) class PickleStorage(StorageInterface): - _PICKLE = ".pckl" _CLOUDPICKLE = ".cpckl" diff --git a/pyiron_workflow/topology.py b/pyiron_workflow/topology.py index fbac705e..c60c9131 100644 --- a/pyiron_workflow/topology.py +++ b/pyiron_workflow/topology.py @@ -142,7 +142,7 @@ def _set_run_connections_according_to_linear_dag(nodes: dict[str, Node]) -> list def set_run_connections_according_to_linear_dag( - nodes: dict[str, Node] + nodes: dict[str, Node], ) -> tuple[list[tuple[SignalChannel, SignalChannel]], list[Node]]: """ Given a set of nodes that all have the same parent, have no upstream data @@ -194,7 +194,7 @@ def _set_run_connections_according_to_dag(nodes: dict[str, Node]) -> list[Node]: def set_run_connections_according_to_dag( - nodes: dict[str, Node] + nodes: dict[str, Node], ) -> tuple[list[tuple[SignalChannel, SignalChannel]], list[Node]]: """ Given a set of nodes 
that all have the same parent, have no upstream data diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py index 00ecb0fb..791e17c8 100644 --- a/pyiron_workflow/workflow.py +++ b/pyiron_workflow/workflow.py @@ -241,7 +241,6 @@ def _after_node_setup( autorun: bool = False, **kwargs, ): - for node in args: self.add_child(node) super()._after_node_setup( diff --git a/pyproject.toml b/pyproject.toml index d00a2102..b17a85b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,3 +84,6 @@ select = [ "I", ] ignore = ["E501"] #ignore line-length violations + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] # Ignore unused imports in init files -- we specify APIs this way diff --git a/tests/benchmark/__init__.py b/tests/benchmark/__init__.py index 102f0479..1a40b4bc 100644 --- a/tests/benchmark/__init__.py +++ b/tests/benchmark/__init__.py @@ -1,3 +1,3 @@ """ Timed tests to make sure critical components stay sufficiently efficient. -""" \ No newline at end of file +""" diff --git a/tests/benchmark/test_benchmark.py b/tests/benchmark/test_benchmark.py index 15b8e3ab..2c18a9cd 100644 --- a/tests/benchmark/test_benchmark.py +++ b/tests/benchmark/test_benchmark.py @@ -6,5 +6,5 @@ def test_nothing(self): self.assertTrue(True) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index 85f3594e..7d4976e7 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -1,3 +1,3 @@ """ Large and potentially slower tests to check how the pieces fit together. -""" \ No newline at end of file +""" diff --git a/tests/integration/test_output_injection.py b/tests/integration/test_output_injection.py index 5535a4a2..786fc0e6 100644 --- a/tests/integration/test_output_injection.py +++ b/tests/integration/test_output_injection.py @@ -8,12 +8,11 @@ class TestOutputInjection(unittest.TestCase): """ I.e. 
the process of inserting new nodes on-the-fly by modifying output channels" """ + def setUp(self) -> None: self.wf = Workflow("injection") self.int = Workflow.create.standard.UserInput(42, autorun=True) - self.list = Workflow.create.standard.UserInput( - list(range(10)), autorun=True - ) + self.list = Workflow.create.standard.UserInput(list(range(10)), autorun=True) def test_equality(self): with self.subTest("True expressions"): @@ -76,7 +75,7 @@ def test_algebra(self): (x // 43, 0 * x), ((x + 1) % x, x + 1 - x), (-x, -1 * x), - (+x, (-x)**2 / x), + (+x, (-x) ** 2 / x), (x, abs(-x)), ]: with self.subTest(f"{lhs.label} == {rhs.label}"): @@ -127,29 +126,24 @@ def test_casts(self): self.assertEqual(self.int.value, round(self.float).value) def test_access(self): - - self.dict = Workflow.create.standard.UserInput( - {"foo": 42}, autorun=True - ) + self.dict = Workflow.create.standard.UserInput({"foo": 42}, autorun=True) class Something: myattr = 1 - self.obj = Workflow.create.standard.UserInput( - Something(), autorun=True - ) + self.obj = Workflow.create.standard.UserInput(Something(), autorun=True) self.assertIsInstance(self.list[0].value, int) self.assertEqual(5, self.list[:5].len().value) self.assertEqual(4, self.list[1:5].len().value) self.assertEqual(3, self.list[-3:].len().value) self.assertEqual(2, self.list[1:5:2].len().value) - + self.assertEqual(42, self.dict["foo"].value) self.assertEqual(1, self.obj.myattr.value) def test_chaining(self): - self.assertFalse((self.list[:self.int//42][0] != 0).value) + self.assertFalse((self.list[: self.int // 42][0] != 0).value) def test_repeated_access_in_parent_scope(self): wf = Workflow("output_manipulation") @@ -162,13 +156,9 @@ def test_repeated_access_in_parent_scope(self): self.assertIs( a, b, - msg="The same operation should re-access an existing node in the parent" - ) - self.assertIsNot( - a, - c, - msg="Unique operations should yield unique nodes" + msg="The same operation should re-access an existing node in the parent", ) + self.assertIsNot(a, c, msg="Unique operations should yield unique nodes") def test_without_parent(self): d1 = self.list[5] @@ -179,14 +169,14 @@ def test_without_parent(self): d1, d2, msg="Outside the scope of a parent, we can't expect to re-access an " - "equivalent node" + "equivalent node", ) self.assertEqual( d1.label, d2.label, - msg="Equivalent operations should nonetheless generate equal labels" + msg="Equivalent operations should nonetheless generate equal labels", ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/integration/test_parallel_speedup.py b/tests/integration/test_parallel_speedup.py index 3f937840..e3cbdefc 100644 --- a/tests/integration/test_parallel_speedup.py +++ b/tests/integration/test_parallel_speedup.py @@ -6,7 +6,6 @@ class TestSpeedup(unittest.TestCase): def test_speedup(self): - def make_workflow(label): wf = Workflow(label) wf.a = Workflow.create.standard.Sleep(t) @@ -36,7 +35,7 @@ def make_workflow(label): dt_cached_serial, 0.01 * t, msg="The cache should be trivially fast compared to actual execution of " - "a sleep node" + "a sleep node", ) wf = make_workflow("parallel") @@ -65,13 +64,13 @@ def make_workflow(label): 0.5 * dt_serial, msg=f"Expected the parallel solution to be at least 2x faster, but got" f"{dt_parallel} and {dt_serial} for parallel and serial times, " - f"respectively" + f"respectively", ) self.assertLess( dt_cached_parallel, 0.01 * t, msg="The cache should be trivially fast compared to actual execution of " - "a sleep node" + 
"a sleep node", ) def test_executor_instructions(self): @@ -95,9 +94,9 @@ def test_executor_instructions(self): 1.1 * t, msg="Expected the sleeps to run in parallel with minimal overhead (since " "it's just a thread pool executor) -- the advantage is that the " - "constructors should survive (de)serialization" + "constructors should survive (de)serialization", ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/integration/test_provenance.py b/tests/integration/test_provenance.py index a763c694..48078d15 100644 --- a/tests/integration/test_provenance.py +++ b/tests/integration/test_provenance.py @@ -39,70 +39,65 @@ def test_executed_provenance(self): out = self.wf() self.assertDictEqual( - self.expected_post, - out, - msg="Sanity check that the graph is executing ok" + self.expected_post, out, msg="Sanity check that the graph is executing ok" ) self.assertListEqual( - ['time', 'prov', 'post'], + ["time", "prov", "post"], self.wf.provenance_by_execution, - msg="Even with a child running on an executor, provenance should log" + msg="Even with a child running on an executor, provenance should log", ) self.assertListEqual( self.wf.provenance_by_execution, self.wf.provenance_by_completion, - msg="The workflow itself is serial and these should be identical." + msg="The workflow itself is serial and these should be identical.", ) self.assertListEqual( - ['t', 'slow', 'fast', 'double'], + ["t", "slow", "fast", "double"], self.wf.prov.provenance_by_execution, msg="Later connections get priority over earlier connections, so we expect " - "the t-node to trigger 'slow' before 'fast'" + "the t-node to trigger 'slow' before 'fast'", ) self.assertListEqual( self.wf.prov.provenance_by_execution, self.wf.prov.provenance_by_completion, msg="The macro is running on an executor, but its children are in serial," - "so completion and execution order should be the same" + "so completion and execution order should be the same", ) def test_execution_vs_completion(self): - with ThreadPoolExecutor(max_workers=2) as exe: self.wf.prov.fast.executor = exe self.wf.prov.slow.executor = exe out = self.wf() self.assertDictEqual( - self.expected_post, - out, - msg="Sanity check that the graph is executing ok" + self.expected_post, out, msg="Sanity check that the graph is executing ok" ) self.assertListEqual( - ['t', 'slow', 'fast', 'double'], + ["t", "slow", "fast", "double"], self.wf.prov.provenance_by_execution, msg="Later connections get priority over earlier connections, so we expect " - "the t-node to trigger 'slow' before 'fast'" + "the t-node to trigger 'slow' before 'fast'", ) self.assertListEqual( - ['t', 'fast', 'slow', 'double'], + ["t", "fast", "slow", "double"], self.wf.prov.provenance_by_completion, msg="Since 'slow' is slow it shouldn't _finish_ until after 'fast' (but " - "still before 'double' since 'double' depends on 'slow')" + "still before 'double' since 'double' depends on 'slow')", ) self.assertListEqual( self.wf.provenance_by_execution, self.wf.provenance_by_completion, - msg="The workflow itself is serial and these should be identical." 
+ msg="The workflow itself is serial and these should be identical.", ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/integration/test_readme.py b/tests/integration/test_readme.py index a4b858c0..579562da 100644 --- a/tests/integration/test_readme.py +++ b/tests/integration/test_readme.py @@ -16,5 +16,5 @@ def test_void(self): pass -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() diff --git a/tests/integration/test_transform.py b/tests/integration/test_transform.py index d3e243b2..fd0370a4 100644 --- a/tests/integration/test_transform.py +++ b/tests/integration/test_transform.py @@ -18,7 +18,7 @@ def test_list(self): self.assertListEqual( list(range(3)), out.outputs.to_list(), - msg="Expected behaviour here is an autoencoder" + msg="Expected behaviour here is an autoencoder", ) inp_class = inputs_to_list_factory(n) @@ -28,18 +28,16 @@ def test_list(self): inp_class, inp.__class__, msg="Regardless of origin, we expect to be constructing the exact same " - "class" + "class", ) self.assertIs(out_class, out.__class__) reloaded = pickle.loads(pickle.dumps(out)) self.assertEqual( - out.label, - reloaded.label, - msg="Transformers should be pickleable" + out.label, reloaded.label, msg="Transformers should be pickleable" ) self.assertDictEqual( out.outputs.to_value_dict(), reloaded.outputs.to_value_dict(), - msg="Transformers should be pickleable" + msg="Transformers should be pickleable", ) diff --git a/tests/integration/test_while.py b/tests/integration/test_while.py index 0462560e..927e0902 100644 --- a/tests/integration/test_while.py +++ b/tests/integration/test_while.py @@ -44,37 +44,31 @@ class TestWhileLoop(unittest.TestCase): def test_while_loop(self): a, b, cap = 0, 2, 5 n = AddWhileLessThan(a, b, cap, autorun=True) - self.assertGreaterEqual( - 6, - n.outputs.greater.value, - msg="Verify output" - ) + self.assertGreaterEqual(6, n.outputs.greater.value, msg="Verify output") self.assertListEqual( - [2, 4, 6], - n.history.outputs.list.value, - msg="Verify loop history logging" + [2, 4, 6], n.history.outputs.list.value, msg="Verify loop history logging" ) self.assertListEqual( [ - 'body', - 'history', - 'condition', - 'switch', - 'body', - 'history', - 'condition', - 'switch', - 'body', - 'history', - 'condition', - 'switch' + "body", + "history", + "condition", + "switch", + "body", + "history", + "condition", + "switch", + "body", + "history", + "condition", + "switch", ], n.provenance_by_execution, - msg="Verify execution order -- the same nodes get run repeatedly in acyclic" + msg="Verify execution order -- the same nodes get run repeatedly in acyclic", ) reloaded = pickle.loads(pickle.dumps(n)) self.assertListEqual( reloaded.history.outputs.list.value, n.history.outputs.list.value, - msg="Should be able to save and re-load cyclic graphs just like usual" + msg="Should be able to save and re-load cyclic graphs just like usual", ) diff --git a/tests/integration/test_workflow.py b/tests/integration/test_workflow.py index 14869ffd..fe6ed412 100644 --- a/tests/integration/test_workflow.py +++ b/tests/integration/test_workflow.py @@ -102,20 +102,18 @@ def sqrt(value=0): ) def test_for_loop(self): - base = 42 to_add = list(range(5)) bulk_loop = Workflow.create.for_node( demo_nodes.OptionallyAdd, iter_on="y", x=base, # Broadcast - y=to_add # Scattered + y=to_add, # Scattered ) out = bulk_loop() for output, expectation in zip( - out.df["sum"].values.tolist(), - [base + v for v in 
to_add], strict=False + out.df["sum"].values.tolist(), [base + v for v in to_add], strict=False ): self.assertAlmostEqual( output, @@ -151,7 +149,7 @@ def test_executors(self): Workflow.create.ProcessPoolExecutor, Workflow.create.ThreadPoolExecutor, Workflow.create.CloudpickleProcessPoolExecutor, - Workflow.create.ExecutorlibExecutor + Workflow.create.ExecutorlibExecutor, ] wf = Workflow("executed") @@ -168,18 +166,20 @@ def test_executors(self): self.assertDictEqual(reference_output, reloaded.outputs.to_value_dict()) for exe_cls in executors: - with self.subTest( - f"{exe_cls.__module__}.{exe_cls.__qualname__} entire workflow" - ), exe_cls() as exe: + with ( + self.subTest( + f"{exe_cls.__module__}.{exe_cls.__qualname__} entire workflow" + ), + exe_cls() as exe, + ): wf.executor = exe self.assertDictEqual( - reference_output, - wf().result().outputs.to_value_dict() + reference_output, wf().result().outputs.to_value_dict() ) self.assertFalse( wf.running, msg="The workflow should stop. For thread pool this required a " - "little sleep" + "little sleep", ) wf.executor = None @@ -193,7 +193,7 @@ def test_executors(self): any(n.running for n in wf), msg=f"All children should be done running -- for thread pools this " f"requires a very short sleep -- got " - f"{[(n.label, n.running) for n in wf]}" + f"{[(n.label, n.running) for n in wf]}", ) for child in wf: child.executor = None @@ -219,7 +219,7 @@ def Sleep(t): second_out, msg="Even thought the _input_ hasn't changed, we expect to avoid the first " "(cached) result by virtue of resetting the cache when the body of " - "the composite graph has changed" + "the composite graph has changed", ) t0 = time.perf_counter() @@ -228,13 +228,13 @@ def Sleep(t): self.assertEqual( third_out, second_out, - msg="This time there is no change and we expect the cached result" + msg="This time there is no change and we expect the cached result", ) self.assertLess( dt, 0.1 * wf.c.inputs.t.value, msg="And because it used the cache we expect it much faster than the sleep " - "time" + "time", ) def test_failure(self): @@ -259,7 +259,10 @@ def test_failure(self): wf.starting_nodes = [wf.a] wf.automate_execution = False - with self.subTest("Check completion"), Workflow.create.ProcessPoolExecutor() as exe: + with ( + self.subTest("Check completion"), + Workflow.create.ProcessPoolExecutor() as exe, + ): wf.c_fails.executor = exe wf(raise_run_exceptions=False) @@ -270,7 +273,7 @@ def test_failure(self): (wf.d_if_success.outputs.user_input.value, NOT_DATA), # Never ran ( wf.d_if_failure.outputs.user_input.value, - wf.d_if_failure.inputs.user_input.value + wf.d_if_failure.inputs.user_input.value, ), (wf.e_fails.outputs.add.value, NOT_DATA), ]: @@ -297,12 +300,12 @@ def test_failure(self): self.assertIn( wf.c_fails.run.full_label, str(e), - msg="Failed node should be identified" + msg="Failed node should be identified", ) self.assertIn( wf.e_fails.run.full_label, str(e), - msg="Indeed, _both_ failed nodes should be identified" + msg="Indeed, _both_ failed nodes should be identified", ) with self.subTest("Check recovery file"): @@ -311,7 +314,7 @@ def test_failure(self): filename=wf.as_path().joinpath("recovery") ), msg="Expect a recovery file to be written for the parent-most" - "object when a child fails" + "object when a child fails", ) finally: wf.delete_storage() @@ -322,9 +325,9 @@ def test_failure(self): f"written a recovery file, so after removing that the whole " f"node directory for the workflow should be cleaned up." 
f"Instead, {wf.as_path()} exists and has content " - f"{[f for f in wf.as_path().iterdir()] if wf.as_path().is_dir() else None}" + f"{[f for f in wf.as_path().iterdir()] if wf.as_path().is_dir() else None}", ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/static/demo_nodes.py b/tests/static/demo_nodes.py index ca180dea..482848d8 100644 --- a/tests/static/demo_nodes.py +++ b/tests/static/demo_nodes.py @@ -2,7 +2,6 @@ A demo node package for the purpose of testing. """ - from pyiron_workflow import Workflow diff --git a/tests/static/docs_submodule/__init__.py b/tests/static/docs_submodule/__init__.py index 254ba920..09e6da67 100644 --- a/tests/static/docs_submodule/__init__.py +++ b/tests/static/docs_submodule/__init__.py @@ -3,4 +3,4 @@ >>> print("This is an example") This is an example -""" \ No newline at end of file +""" diff --git a/tests/static/docs_submodule/bad_class.py b/tests/static/docs_submodule/bad_class.py index a600c25d..3fce38bc 100644 --- a/tests/static/docs_submodule/bad_class.py +++ b/tests/static/docs_submodule/bad_class.py @@ -4,4 +4,4 @@ class Documented: >>> print(42) This is not the expected output - """ \ No newline at end of file + """ diff --git a/tests/static/docs_submodule/bad_init_example/__init__.py b/tests/static/docs_submodule/bad_init_example/__init__.py index 4ebcff71..3974cfcb 100644 --- a/tests/static/docs_submodule/bad_init_example/__init__.py +++ b/tests/static/docs_submodule/bad_init_example/__init__.py @@ -2,4 +2,4 @@ Let's test just __init__ >>> 0/1 -""" \ No newline at end of file +""" diff --git a/tests/static/docs_submodule/good_function.py b/tests/static/docs_submodule/good_function.py index f9340d49..eba2a99f 100644 --- a/tests/static/docs_submodule/good_function.py +++ b/tests/static/docs_submodule/good_function.py @@ -1,8 +1,8 @@ def function(): """ Here is an A-OK docstring - + >>> print(42) 42 """ - return None \ No newline at end of file + return None diff --git a/tests/static/docs_submodule/mix.py b/tests/static/docs_submodule/mix.py index 1884c493..de4ee4f6 100644 --- a/tests/static/docs_submodule/mix.py +++ b/tests/static/docs_submodule/mix.py @@ -13,6 +13,7 @@ def bad(): """ return + def error(): """ >>> 1/0 @@ -40,4 +41,4 @@ def bad(self): def error(self): """ >>> 1/0 - """ \ No newline at end of file + """ diff --git a/tests/static/nodes_subpackage/demo_nodes.py b/tests/static/nodes_subpackage/demo_nodes.py index e039ab7a..487dd220 100644 --- a/tests/static/nodes_subpackage/demo_nodes.py +++ b/tests/static/nodes_subpackage/demo_nodes.py @@ -2,7 +2,6 @@ A demo node package for the purpose of testing. """ - from pyiron_workflow import Workflow diff --git a/tests/static/nodes_subpackage/subsub_package/demo_nodes.py b/tests/static/nodes_subpackage/subsub_package/demo_nodes.py index e039ab7a..487dd220 100644 --- a/tests/static/nodes_subpackage/subsub_package/demo_nodes.py +++ b/tests/static/nodes_subpackage/subsub_package/demo_nodes.py @@ -2,7 +2,6 @@ A demo node package for the purpose of testing. """ - from pyiron_workflow import Workflow diff --git a/tests/static/nodes_subpackage/subsub_sibling/demo_nodes.py b/tests/static/nodes_subpackage/subsub_sibling/demo_nodes.py index e039ab7a..487dd220 100644 --- a/tests/static/nodes_subpackage/subsub_sibling/demo_nodes.py +++ b/tests/static/nodes_subpackage/subsub_sibling/demo_nodes.py @@ -2,7 +2,6 @@ A demo node package for the purpose of testing. 
""" - from pyiron_workflow import Workflow diff --git a/tests/unit/executors/test_cloudprocesspool.py b/tests/unit/executors/test_cloudprocesspool.py index c322207d..8e70bd1e 100644 --- a/tests/unit/executors/test_cloudprocesspool.py +++ b/tests/unit/executors/test_cloudprocesspool.py @@ -12,6 +12,7 @@ class Foo: """ A base class to be dynamically modified for testing CloudpickleProcessPoolExecutor. """ + def __init__(self, fnc: callable): self.fnc = fnc self.result = None @@ -31,23 +32,18 @@ def dynamic_foo(): Overrides the `fnc` input of `Foo` with the decorated function. """ + def as_dynamic_foo(fnc: callable): return type( "DynamicFoo", (Foo,), # Define parentage - { - "__init__": partialmethod( - Foo.__init__, - fnc - ) - }, + {"__init__": partialmethod(Foo.__init__, fnc)}, ) return as_dynamic_foo class TestCloudpickleProcessPoolExecutor(unittest.TestCase): - def test_unpickleable_callable(self): """ We should be able to use an unpickleable callable -- in this case, a method of @@ -62,13 +58,10 @@ def slowly_returns_42(): dynamic_42 = slowly_returns_42() # Instantiate the dynamically defined class self.assertIsInstance( - dynamic_42, - Foo, - msg="Just a sanity check that the test is set up right" + dynamic_42, Foo, msg="Just a sanity check that the test is set up right" ) self.assertIsNone( - dynamic_42.result, - msg="Just a sanity check that the test is set up right" + dynamic_42.result, msg="Just a sanity check that the test is set up right" ) executor = CloudpickleProcessPoolExecutor() fs = executor.submit(dynamic_42.run) @@ -103,7 +96,7 @@ def slowly_returns_unpickleable(): self.assertIsInstance( fs.result(timeout=120), Foo, - msg="The custom future should be unpickling the result" + msg="The custom future should be unpickling the result", ) self.assertEqual(fs.result(timeout=120).result, "it was an inside job!") @@ -157,7 +150,7 @@ def slow(): self.assertEqual( fs.result(timeout=60), fortytwo, - msg="waiting long enough should get the result" + msg="waiting long enough should get the result", ) with self.assertRaises(TimeoutError): @@ -165,5 +158,5 @@ def slow(): fs.result(timeout=0.0001) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/mixin/test_preview.py b/tests/unit/mixin/test_preview.py index 26810f30..7983155d 100644 --- a/tests/unit/mixin/test_preview.py +++ b/tests/unit/mixin/test_preview.py @@ -30,7 +30,7 @@ def scraper_factory( "__module__": io_defining_function.__module__, "_output_labels": None if len(output_labels) == 0 else output_labels, "_validate_output_labels": validate_output_labels, - "_io_defining_function_uses_self": io_defining_function_uses_self + "_io_defining_function_uses_self": io_defining_function_uses_self, }, {}, ) @@ -49,6 +49,7 @@ def scraper_decorator(fnc): factory_made._reduce_imports_as = (fnc.__module__, fnc.__qualname__) factory_made.preview_io() return factory_made + return scraper_decorator @@ -69,20 +70,23 @@ def Mixed(x, y: int = 42): {"x": (None, NOT_DATA), "y": (int, 42)}, Mixed.preview_inputs(), msg="Input specifications should be available at the class level, with or " - "without type hints and/or defaults provided." 
+ "without type hints and/or defaults provided.", ) - with self.subTest("Protected"), self.assertRaises( - ValueError, - msg="Inputs must not overlap with __init__ signature terms" + with ( + self.subTest("Protected"), + self.assertRaises( + ValueError, msg="Inputs must not overlap with __init__ signature terms" + ), ): + @as_scraper() def Selfish(self, x): return x def test_preview_outputs(self): - with self.subTest("Plain"): + @as_scraper() def Return(x): return x @@ -90,10 +94,11 @@ def Return(x): self.assertDictEqual( {"x": None}, Return.preview_outputs(), - msg="Should parse without label or hint." + msg="Should parse without label or hint.", ) with self.subTest("Labeled"): + @as_scraper("y") def LabeledReturn(x) -> None: return x @@ -101,14 +106,15 @@ def LabeledReturn(x) -> None: self.assertDictEqual( {"y": type(None)}, LabeledReturn.preview_outputs(), - msg="Should parse with label and hint." + msg="Should parse with label and hint.", ) with self.subTest("Hint-return count mismatch"): with self.assertRaises( ValueError, - msg="Should fail when scraping incommensurate hints and returns" + msg="Should fail when scraping incommensurate hints and returns", ): + @as_scraper() def HintMismatchesScraped(x) -> int: y, z = 5.0, 5 @@ -116,8 +122,9 @@ def HintMismatchesScraped(x) -> int: with self.assertRaises( ValueError, - msg="Should fail when provided labels are incommensurate with hints" + msg="Should fail when provided labels are incommensurate with hints", ): + @as_scraper("xo", "yo", "zo") def HintMismatchesProvided(x) -> int: y, z = 5.0, 5 @@ -127,8 +134,9 @@ def HintMismatchesProvided(x) -> int: with self.assertRaises( ValueError, msg="The nuber of labels -- if explicitly provided -- must be commensurate " - "with the number of returned items" + "with the number of returned items", ): + @as_scraper("xo", "yo") def LabelsMismatchScraped(x) -> tuple[int, float]: _y, _z = 5.0, 5 @@ -142,14 +150,15 @@ def IgnoreScraping(x) -> tuple[int, float]: self.assertDictEqual( {"x0": int, "x1": float}, IgnoreScraping.preview_outputs(), - msg="Returned tuples can be received by force" + msg="Returned tuples can be received by force", ) with self.subTest("Multiple returns"): with self.assertRaises( ValueError, - msg="Branched returns cannot be scraped and will fail on validation" + msg="Branched returns cannot be scraped and will fail on validation", ): + @as_scraper("truth") def Branched(x) -> bool: if x <= 0: # noqa: SIM103 @@ -160,13 +169,15 @@ def Branched(x) -> bool: @as_scraper("truth", validate_output_labels=False) def Branched(x) -> bool: return not x <= 0 + self.assertDictEqual( {"truth": bool}, Branched.preview_outputs(), - msg="We can force-override this at our own risk." + msg="We can force-override this at our own risk.", ) with self.subTest("Uninspectable function"): + def _uninspectable(): template = dedent(""" def __source_code_not_available(x): @@ -180,7 +191,7 @@ def __source_code_not_available(x): with self.assertRaises( OSError, msg="If the source code cannot be inspected for output labels, they " - "_must_ be provided." 
+ "_must_ be provided.", ): as_scraper()(f) @@ -190,5 +201,5 @@ def __source_code_not_available(x): self.assertIn( f"WARNING:{logger.name}:" + no_output_validation_warning(new_cls), log.output, - msg="Verify that the expected warning appears in the log" + msg="Verify that the expected warning appears in the log", ) diff --git a/tests/unit/mixin/test_run.py b/tests/unit/mixin/test_run.py index 00d746fb..0af18380 100644 --- a/tests/unit/mixin/test_run.py +++ b/tests/unit/mixin/test_run.py @@ -53,7 +53,7 @@ def test_runnable_not_ready(self): self.assertTrue( runnable.ready, - msg="Freshly instantiated, it is neither running nor failed!" + msg="Freshly instantiated, it is neither running nor failed!", ) with self.subTest("Running"): @@ -83,7 +83,7 @@ def test_runnable_not_ready(self): result, msg="We should be able to bypass the readiness check with a flag, and " "in this simple case expect to get perfectly normal behaviour " - "afterwards" + "afterwards", ) def test_failure(self): @@ -92,8 +92,7 @@ def test_failure(self): with self.assertRaises(RuntimeError): runnable.run() self.assertTrue( - runnable.failed, - msg="Encountering an error should set status to failed" + runnable.failed, msg="Encountering an error should set status to failed" ) runnable.failed = False @@ -101,7 +100,7 @@ def test_failure(self): self.assertTrue( runnable.failed, msg="We should be able to stop the exception from getting raised, but the " - "status should still be failed" + "status should still be failed", ) def test_runnable_run_local(self): @@ -109,18 +108,15 @@ def test_runnable_run_local(self): result = runnable.run() self.assertIsNone( - runnable.future, - msg="Without an executor, we expect no future" + runnable.future, msg="Without an executor, we expect no future" ) self.assertDictEqual( - runnable.expected_run_output, - result, - msg="Expected the result" + runnable.expected_run_output, result, msg="Expected the result" ) self.assertDictEqual( runnable.expected_processed_value, runnable.processed, - msg="Expected the result, including post-processing 'bar' value" + msg="Expected the result, including post-processing 'bar' value", ) def test_runnable_run_with_executor(self): @@ -142,29 +138,26 @@ def maybe_get_executor(get_executor): result = runnable.run() self.assertIsInstance( - result, - Future, - msg="With an executor, a future should be returned" + result, Future, msg="With an executor, a future should be returned" ) self.assertIs( result, runnable.future, - msg="With an executor, the future attribute should get populated" + msg="With an executor, the future attribute should get populated", ) self.assertDictEqual( runnable.expected_run_output, result.result(timeout=30), - msg="Expected the result (after waiting for it to compute, of course)" + msg="Expected the result (after waiting for it to compute, of course)", ) self.assertDictEqual( runnable.expected_processed_value, runnable.processed, - msg="Expected the result, including post-processing 'bar' value" + msg="Expected the result, including post-processing 'bar' value", ) with self.assertRaises( - NotImplementedError, - msg="That's not an executor at all" + NotImplementedError, msg="That's not an executor at all" ): runnable.executor = 42 runnable.run() @@ -172,11 +165,11 @@ def maybe_get_executor(get_executor): with self.assertRaises( TypeError, msg="Callables are ok, but if they don't return an executor we should get " - "and error." 
+ "and error.", ): runnable.executor = (maybe_get_executor, (False,), {}) runnable.run() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/mixin/test_semantics.py b/tests/unit/mixin/test_semantics.py index 72333230..103b41dd 100644 --- a/tests/unit/mixin/test_semantics.py +++ b/tests/unit/mixin/test_semantics.py @@ -23,14 +23,14 @@ def test_getattr(self): self.assertIn( "Did you mean 'middle_sub' and not 'Middle_sub'", str(context.exception), - msg="middle_sub must be suggested as it is close to Middle_sub" + msg="middle_sub must be suggested as it is close to Middle_sub", ) with self.assertRaises(AttributeError) as context: _ = self.middle1.my_neighbor_stinks self.assertNotIn( "Did you mean", str(context.exception), - msg="Nothings should be suggested for my_neighbor_stinks" + msg="Nothings should be suggested for my_neighbor_stinks", ) def test_label_validity(self): @@ -39,8 +39,7 @@ def test_label_validity(self): def test_label_delimiter(self): with self.assertRaises( - ValueError, - msg=f"Delimiter '{Semantic.semantic_delimiter}' not allowed" + ValueError, msg=f"Delimiter '{Semantic.semantic_delimiter}' not allowed" ): Semantic(f"invalid{Semantic.semantic_delimiter}label") @@ -51,7 +50,7 @@ def test_semantic_delimiter(self): msg="This is just a hard-code to the current value, update it freely so " "the test passes; if it fails it's just a reminder that your change is " "not backwards compatible, and the next release number should reflect " - "this." + "this.", ) def test_parent(self): @@ -61,14 +60,12 @@ def test_parent(self): with self.subTest(f"{ParentMost.__name__} exceptions"): with self.assertRaises( - TypeError, - msg=f"{ParentMost.__name__} instances can't have parent" + TypeError, msg=f"{ParentMost.__name__} instances can't have parent" ): self.root.parent = SemanticParent(label="foo") with self.assertRaises( - TypeError, - msg=f"{ParentMost.__name__} instances can't be children" + TypeError, msg=f"{ParentMost.__name__} instances can't be children" ): some_parent = SemanticParent(label="bar") some_parent.add_child(self.root) @@ -96,59 +93,56 @@ def test_root(self): def test_as_path(self): self.assertEqual( - self.root.as_path(), - Path.cwd() / self.root.label, - msg="Default None root" + self.root.as_path(), Path.cwd() / self.root.label, msg="Default None root" ) self.assertEqual( self.child1.as_path(root=".."), Path("..") / self.root.label / self.child1.label, - msg="String root" + msg="String root", ) self.assertEqual( self.middle2.as_path(root=Path("..", "..")), ( - Path("..", "..") / - self.root.label / - self.middle1.label / - self.middle2.label + Path("..", "..") + / self.root.label + / self.middle1.label + / self.middle2.label ), - msg="Path root" + msg="Path root", ) def test_detached_parent_path(self): orphan = Semantic("orphan") orphan.__setstate__(self.child2.__getstate__()) self.assertIsNone( - orphan.parent, - msg="We still should not explicitly have a parent" + orphan.parent, msg="We still should not explicitly have a parent" ) self.assertListEqual( orphan.detached_parent_path.split(orphan.semantic_delimiter), self.child2.semantic_path.split(orphan.semantic_delimiter)[:-1], msg="Despite not having a parent, the detached path should store semantic " - "path info through the get/set state routine" + "path info through the get/set state routine", ) self.assertEqual( orphan.semantic_path, self.child2.semantic_path, msg="The detached path should carry through to semantic path in the " - "absence of a parent" + "absence of a 
parent", ) orphan.label = "orphan" # Re-set label after getting state orphan.parent = self.child2.parent self.assertIsNone( orphan.detached_parent_path, msg="Detached paths aren't necessary and shouldn't co-exist with the " - "presence of a parent" + "presence of a parent", ) self.assertListEqual( orphan.semantic_path.split(orphan.semantic_delimiter)[:-1], self.child2.semantic_path.split(self.child2.semantic_delimiter)[:-1], msg="Sanity check -- except for the now-different labels, we should be " - "recovering the usual semantic path on setting a parent." + "recovering the usual semantic path on setting a parent.", ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/nodes/test_composite.py b/tests/unit/nodes/test_composite.py index 8560f8d7..3569dedb 100644 --- a/tests/unit/nodes/test_composite.py +++ b/tests/unit/nodes/test_composite.py @@ -46,7 +46,6 @@ def outputs(self) -> OutputsWithInjection: class TestComposite(unittest.TestCase): - def setUp(self) -> None: self.comp = AComposite("my_composite") super().setUp() @@ -59,11 +58,11 @@ def foo(x: int = 0) -> int: from_class = foo() self.assertEqual(from_class.run(), 1, msg="Node should be fully functioning") self.assertIsNone( - from_class.parent, - msg="Wrapping from the class should give no parent" + from_class.parent, msg="Wrapping from the class should give no parent" ) comp = self.comp + @comp.wrap.as_function_node("y") def bar(x: int = 0) -> int: return x + 2 @@ -73,51 +72,47 @@ def bar(x: int = 0) -> int: self.assertIsNone( from_instance.parent, msg="Wrappers are not creators, wrapping from the instance makes no " - "difference" + "difference", ) def test_node_addition(self): # Validate the four ways to add a node self.comp.add_child(Composite.create.function_node(plus_one, label="foo")) - self.comp.baz = self.comp.create.function_node(plus_one, label="whatever_baz_gets_used") + self.comp.baz = self.comp.create.function_node( + plus_one, label="whatever_baz_gets_used" + ) Composite.create.function_node(plus_one, label="qux", parent=self.comp) self.assertListEqual( list(self.comp.children.keys()), ["foo", "baz", "qux"], - msg="Expected every above syntax to add a node OK" + msg="Expected every above syntax to add a node OK", ) print(self.comp.children) self.comp.boa = self.comp.qux self.assertListEqual( list(self.comp.children.keys()), ["foo", "baz", "boa"], - msg="Reassignment should remove the original instance" + msg="Reassignment should remove the original instance", ) - + def test_node_access(self): node = Composite.create.function_node(plus_one) self.comp.child = node self.assertIs( - self.comp.child, - node, - msg="Access should be possible by attribute" - ) - self.assertIs( - self.comp["child"], - node, - msg="Access should be possible by item" + self.comp.child, node, msg="Access should be possible by attribute" ) + self.assertIs(self.comp["child"], node, msg="Access should be possible by item") self.assertIs( self.comp.children["child"], node, - msg="Access should be possible by item on children collection" + msg="Access should be possible by item on children collection", ) for n in self.comp: self.assertIs( node, n, - msg="Should be able to iterate through (the one and only) nodes" + msg="Should be able to iterate through (the one and only) nodes", ) with self.assertRaises( @@ -125,7 +120,7 @@ def test_node_access(self): msg="Composites should override the attribute access portion of their " "`HasIOWithInjection` mixin to guarantee that attribute access is " "always looking 
for children. If attribute access is actually desired, " - " it can be accomplished with a `GetAttr` node." + " it can be accomplished with a `GetAttr` node.", ): self.comp.not_a_child_or_attribute # noqa: B018 @@ -144,12 +139,12 @@ def test_node_removal(self): self.assertListEqual( [(node.inputs.x, self.comp.owned.outputs.y)], disconnected, - msg="Removal should return destroyed connections" + msg="Removal should return destroyed connections", ) self.assertListEqual( self.comp.starting_nodes, [], - msg="Removal should also remove from starting nodes" + msg="Removal should also remove from starting nodes", ) node_owned = self.comp.owned @@ -157,12 +152,10 @@ def test_node_removal(self): self.assertEqual( node_owned.parent, None, - msg="Should be able to remove nodes by label as well as by object" + msg="Should be able to remove nodes by label as well as by object", ) self.assertListEqual( - [], - disconnections, - msg="node1 should have no connections left" + [], disconnections, msg="node1 should have no connections left" ) def test_label_uniqueness(self): @@ -176,7 +169,7 @@ def test_label_uniqueness(self): with self.assertRaises( AttributeError, msg="The provided label is ok, but then assigning to baz should give " - "trouble since that name is already occupied" + "trouble since that name is already occupied", ): self.comp.foo = Composite.create.function_node(plus_one, label="whatever") @@ -207,13 +200,13 @@ def test_label_uniqueness(self): self.assertEqual( 2, len(self.comp), - msg="Without strict naming, we should be able to add to an existing name" + msg="Without strict naming, we should be able to add to an existing name", ) self.assertListEqual( ["foo", "foo0"], list(self.comp.children.keys()), msg="When adding a node with an existing name and relaxed naming, the new " - "node should get an index on its label so each label is still unique" + "node should get an index on its label so each label is still unique", ) def test_singular_ownership(self): @@ -230,9 +223,7 @@ def test_singular_ownership(self): comp1.remove_child(node2) comp2.add_child(node2) self.assertEqual( - node2.parent, - comp2, - msg="Freed nodes should be able to join other parents" + node2.parent, comp2, msg="Freed nodes should be able to join other parents" ) def test_replace(self): @@ -271,7 +262,7 @@ def different_output_channel(x: int = 0) -> int: self.comp.replace_child(n1, replacement) out = self.comp.run(n1__x=0) self.assertEqual( - (0+2) + 1 + 1, out.n3__y, msg="Should be able to replace by instance" + (0 + 2) + 1 + 1, out.n3__y, msg="Should be able to replace by instance" ) self.assertEqual( 0 - 2, out.n1__minus, msg="Replacement output should also appear" @@ -302,16 +293,16 @@ def different_output_channel(x: int = 0) -> int: self.comp.n3, replacement, msg="Sanity check -- when replacing with class, a _new_ instance " - "should be created" + "should be created", ) self.comp.replace_child(self.comp.n3, n3) self.comp.n1 = x_plus_minus_z self.assertEqual( - (0+2) + 1 + 1, + (0 + 2) + 1 + 1, self.comp.run(n1__x=0).n3__y, msg="Assigning a new _class_ to an existing node should be a shortcut " - "for replacement" + "for replacement", ) self.comp.replace_child(self.comp.n1, n1) # Return to original state @@ -320,7 +311,7 @@ def different_output_channel(x: int = 0) -> int: (0 + 10) + 1 + 1, self.comp.run(n1__z=0).n3__y, msg="Different IO should be compatible as long as what's missing is " - "not connected" + "not connected", ) self.comp.replace_child(self.comp.n1, n1) @@ -329,7 +320,7 @@ def 
different_output_channel(x: int = 0) -> int: (0 + 1) + 1 + 100, self.comp.run(n1__x=0).n3__z, msg="Different IO should be compatible as long as what's missing is " - "not connected" + "not connected", ) self.comp.replace_child(self.comp.n3, n3) @@ -340,23 +331,20 @@ def different_output_channel(x: int = 0) -> int: another_node = x_plus_minus_z(parent=another_comp) with self.assertRaises( - ValueError, - msg="Should fail when replacement has a parent" + ValueError, msg="Should fail when replacement has a parent" ): self.comp.replace_child(self.comp.n1, another_node) another_comp.remove_child(another_node) another_node.inputs.x = replacement.outputs.y with self.assertRaises( - ValueError, - msg="Should fail when replacement is connected" + ValueError, msg="Should fail when replacement is connected" ): self.comp.replace_child(self.comp.n1, another_node) another_node.disconnect() with self.assertRaises( - ValueError, - msg="Should fail if the node being replaced isn't a child" + ValueError, msg="Should fail if the node being replaced isn't a child" ): self.comp.replace_child(replacement, another_node) @@ -365,20 +353,19 @@ def wrong_hint(x: float = 0) -> float: return x + 1.1 with self.assertRaises( - TypeError, - msg="Should not be able to replace with the wrong type hints" + TypeError, msg="Should not be able to replace with the wrong type hints" ): self.comp.n1 = wrong_hint with self.assertRaises( ConnectionCopyError, - msg="Should not be able to replace with any missing connected channels" + msg="Should not be able to replace with any missing connected channels", ): self.comp.n2 = different_input_channel with self.assertRaises( ConnectionCopyError, - msg="Should not be able to replace with any missing connected channels" + msg="Should not be able to replace with any missing connected channels", ): self.comp.n2 = different_output_channel @@ -386,7 +373,7 @@ def wrong_hint(x: float = 0) -> float: 3, self.comp.run().n3__y, msg="Failed replacements should always restore the original state " - "cleanly" + "cleanly", ) def test_length(self): @@ -396,7 +383,7 @@ def test_length(self): self.assertEqual( l1 + 1, len(self.comp), - msg="Expected length to count the number of children" + msg="Expected length to count the number of children", ) def test_run(self): @@ -410,12 +397,12 @@ def test_run(self): self.assertEqual( 2, self.comp.n2.outputs.y.value, - msg="Expected to start from starting node and propagate" + msg="Expected to start from starting node and propagate", ) self.assertIs( NOT_DATA, self.comp.n3.outputs.y.value, - msg="n3 was omitted from the execution diagram, it should not have run" + msg="n3 was omitted from the execution diagram, it should not have run", ) def test_set_run_signals_to_dag(self): @@ -426,25 +413,19 @@ def test_set_run_signals_to_dag(self): self.comp.set_run_signals_to_dag_execution() self.comp.run() self.assertEqual( - 1, - self.comp.n1.outputs.y.value, - msg="Expected all nodes to run" + 1, self.comp.n1.outputs.y.value, msg="Expected all nodes to run" ) self.assertEqual( - 2, - self.comp.n2.outputs.y.value, - msg="Expected all nodes to run" + 2, self.comp.n2.outputs.y.value, msg="Expected all nodes to run" ) self.assertEqual( - 43, - self.comp.n3.outputs.y.value, - msg="Expected all nodes to run" + 43, self.comp.n3.outputs.y.value, msg="Expected all nodes to run" ) self.comp.n1.inputs.x = self.comp.n2 with self.assertRaises( CircularDataFlowError, - msg="Should not be able to automate graphs with circular data" + msg="Should not be able to automate graphs with 
circular data", ): self.comp.set_run_signals_to_dag_execution() @@ -460,7 +441,7 @@ def test_return(self): 1, self.comp.outputs.n1__y.value, msg="Sanity check that the output has been filled and is stored under the " - "name we think it is" + "name we think it is", ) # Make sure the returned object is functionally a dot-dict self.assertEqual(1, out["n1__y"], msg="Should work with item-access") @@ -469,12 +450,12 @@ def test_return(self): self.assertIs( not_dottable_name_node, self.comp.children[not_dottable_string], - msg="Should be able to access the node by item" + msg="Should be able to access the node by item", ) self.assertEqual( 43, out[not_dottable_string + "__y"], - msg="Should always be able to fall back to item access with crazy labels" + msg="Should always be able to fall back to item access with crazy labels", ) def test_de_activate_strict_connections(self): @@ -482,17 +463,17 @@ def test_de_activate_strict_connections(self): self.comp.sub_comp.n1 = Composite.create.function_node(plus_one, x=0) self.assertTrue( self.comp.sub_comp.n1.inputs.x.strict_hints, - msg="Sanity check that test starts in the expected condition" + msg="Sanity check that test starts in the expected condition", ) self.comp.deactivate_strict_hints() self.assertFalse( self.comp.sub_comp.n1.inputs.x.strict_hints, - msg="Deactivating should propagate to children" + msg="Deactivating should propagate to children", ) self.comp.activate_strict_hints() self.assertTrue( self.comp.sub_comp.n1.inputs.x.strict_hints, - msg="Activating should propagate to children" + msg="Activating should propagate to children", ) def test_graph_info(self): @@ -505,57 +486,56 @@ def test_graph_info(self): self.assertEqual( top.semantic_delimiter + top.label, top.graph_path, - msg="The parent-most node should be its own path." + msg="The parent-most node should be its own path.", ) self.assertTrue( top.middle_composite.graph_path.startswith(top.graph_path), - msg="The path should go to the parent-most object." + msg="The path should go to the parent-most object.", ) self.assertTrue( top.middle_function.graph_path.startswith(top.graph_path), - msg="The path should go to the parent-most object." + msg="The path should go to the parent-most object.", ) self.assertTrue( top.middle_composite.deep_node.graph_path.startswith(top.graph_path), msg="The path should go to the parent-most object, recursively from " - "all depths." + "all depths.", ) with self.subTest("test_graph_root"): self.assertIs( top, top.graph_root, - msg="The parent-most node should be its own graph_root." + msg="The parent-most node should be its own graph_root.", ) self.assertIs( top, top.middle_composite.graph_root, - msg="The parent-most node should be the graph_root." + msg="The parent-most node should be the graph_root.", ) self.assertIs( top, top.middle_function.graph_root, - msg="The parent-most node should be the graph_root." + msg="The parent-most node should be the graph_root.", ) self.assertIs( top, top.middle_composite.deep_node.graph_root, msg="The parent-most node should be the graph_root, recursively accessible " - "from all depths." 
+ "from all depths.", ) def test_import_ready(self): - totally_findable = demo_nodes.OptionallyAdd() self.assertTrue( totally_findable.import_ready, - msg="The node class is well defined and in an importable module" + msg="The node class is well defined and in an importable module", ) bad_class = demo_nodes.Dynamic() self.assertFalse( bad_class.import_ready, msg="The node is in an importable location, but the imported object is not " - "the node class (but rather the node function)" + "the node class (but rather the node function)", ) with self.subTest(msg="Made up module"): og_module = totally_findable.__class__.__module__ @@ -564,7 +544,7 @@ def test_import_ready(self): self.assertFalse( totally_findable.import_ready, msg="The node class is well defined, but the module is not in the " - "python path so import fails" + "python path so import fails", ) finally: totally_findable.__class__.__module__ = og_module # Fix what you broke @@ -572,17 +552,17 @@ def test_import_ready(self): self.assertTrue( self.comp.import_ready, msg="Sanity check on initial condition -- tests are in the path, so this " - "is importable" + "is importable", ) self.comp.totally_findable = totally_findable self.assertTrue( self.comp.import_ready, - msg="Adding importable children should leave the parent import-ready" + msg="Adding importable children should leave the parent import-ready", ) self.comp.bad_class = bad_class self.assertFalse( self.comp.import_ready, - msg="Adding un-importable children should make the parent not import ready" + msg="Adding un-importable children should make the parent not import ready", ) def test_with_executor(self): @@ -595,13 +575,13 @@ def test_with_executor(self): self.comp.sub_composite.parent, self.comp, msg="After processing a remotely-executed self, the local self should " - "retain its parent" + "retain its parent", ) self.assertIs( self.comp.sub_composite.executor, exe, msg="After processing a remotely-executed self, the local self should " - "retain its executor" + "retain its executor", ) def test_result_serialization(self): @@ -633,9 +613,9 @@ def test_result_serialization(self): self.assertFalse( self.comp.as_path().is_dir(), msg="Actually, we expect cleanup to have removed empty directories up to " - "and including the semantic root's own directory" + "and including the semantic root's own directory", ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/nodes/test_for_loop.py b/tests/unit/nodes/test_for_loop.py index 5e824c59..e3a3d3fd 100644 --- a/tests/unit/nodes/test_for_loop.py +++ b/tests/unit/nodes/test_for_loop.py @@ -21,7 +21,6 @@ class TestDictionaryToIndexMaps(unittest.TestCase): - def test_no_keys(self): data = {"key": 5} with self.assertRaises(ValueError): @@ -76,7 +75,7 @@ def test_valid_data_nested_and_zipped(self): "nested1": [2, 3], "nested2": [4, 5, 6], "zipped1": [7, 8, 9, 10], - "zipped2": [11, 12, 13, 14, 15] + "zipped2": [11, 12, 13, 14, 15], } nested_keys = ("nested1", "nested2") zipped_keys = ("zipped1", "zipped2") @@ -85,20 +84,20 @@ def test_valid_data_nested_and_zipped(self): nested_keys[0]: n_idx, nested_keys[1]: n_idx2, zipped_keys[0]: z_idx, - zipped_keys[1]: z_idx2 + zipped_keys[1]: z_idx2, } for n_idx, n_idx2 in product( - range(len(data["nested1"])), - range(len(data["nested2"])) + range(len(data["nested1"])), range(len(data["nested2"])) ) for z_idx, z_idx2 in zip( - range(len(data["zipped1"])), - range(len(data["zipped2"])), strict=False + range(len(data["zipped1"])), 
range(len(data["zipped2"])), strict=False ) ) self.assertEqual( expected_maps, - dictionary_to_index_maps(data, nested_keys=nested_keys, zipped_keys=zipped_keys), + dictionary_to_index_maps( + data, nested_keys=nested_keys, zipped_keys=zipped_keys + ), ) @@ -110,23 +109,29 @@ def FiveTogether( d: int = 3, e: str = "foobar", ): - return (a, b, c, d, e,), + return ( + ( + a, + b, + c, + d, + e, + ), + ) class TestForNode(unittest.TestCase): - @classmethod def setUpClass(cls) -> None: super().setUpClass() ensure_tests_in_python_path() from static.demo_nodes import AddThree + cls.AddThree = AddThree @staticmethod def _get_column( - output: DotDict, - output_as_dataframe: bool, - column_name: str ="together" + output: DotDict, output_as_dataframe: bool, column_name: str = "together" ): """ Facilitate testing different output types @@ -150,39 +155,36 @@ def test_iter_only(self): with self.subTest(f"output_as_dataframe {output_as_dataframe}"): for_instance = for_node( FiveTogether, - iter_on=("a", "b",), + iter_on=( + "a", + "b", + ), a=[42, 43, 44], b=[13, 14], output_as_dataframe=output_as_dataframe, ) self.assertIsNone( - for_instance.nrows, - msg="Haven't run yet, so there is no size" + for_instance.nrows, msg="Haven't run yet, so there is no size" ) self.assertIsNone( - for_instance.ncols, - msg="Haven't run yet, so there is no size" + for_instance.ncols, msg="Haven't run yet, so there is no size" ) out = for_instance(e="iter") self.assertIsInstance( out[list(out.keys())[0]], DataFrame if output_as_dataframe else list, - msg="Expected output type to correspond to boolean request" - ) - self.assertEqual( - for_instance.nrows, - 3 * 2, - msg="Expect nested loops" + msg="Expected output type to correspond to boolean request", ) + self.assertEqual(for_instance.nrows, 3 * 2, msg="Expect nested loops") self.assertEqual( for_instance.ncols, 1 + 2, - msg="Dataframe should only hold output and _looped_ input" + msg="Dataframe should only hold output and _looped_ input", ) self.assertTupleEqual( self._get_column(out, output_as_dataframe=output_as_dataframe)[1], ((42, 14, 2, 3, "iter"),), - msg="Iter should get nested, broadcast broadcast, else take default" + msg="Iter should get nested, broadcast broadcast, else take default", ) def test_zip_only(self): @@ -190,7 +192,10 @@ def test_zip_only(self): with self.subTest(f"output_as_dataframe {output_as_dataframe}"): for_instance = for_node( FiveTogether, - zip_on=("c", "d",), + zip_on=( + "c", + "d", + ), e="zip", output_as_dataframe=output_as_dataframe, ) @@ -199,17 +204,17 @@ def test_zip_only(self): for_instance.nrows, 2, msg="Expect zipping with the python convention of truncating to " - "shortest" + "shortest", ) self.assertEqual( for_instance.ncols, 1 + 2, - msg="Dataframe should only hold output and _looped_ input" + msg="Dataframe should only hold output and _looped_ input", ) self.assertTupleEqual( self._get_column(out, output_as_dataframe)[1], ((0, 1, 101, -2, "zip"),), - msg="Zipped should get zipped, broadcast broadcast, else take default" + msg="Zipped should get zipped, broadcast broadcast, else take default", ) def test_iter_and_zip(self): @@ -217,10 +222,16 @@ def test_iter_and_zip(self): with self.subTest(f"output_as_dataframe {output_as_dataframe}"): for_instance = for_node( FiveTogether, - iter_on=("a", "b",), + iter_on=( + "a", + "b", + ), a=[42, 43, 44], b=[13, 14], - zip_on=("c", "d",), + zip_on=( + "c", + "d", + ), e="both", output_as_dataframe=output_as_dataframe, ) @@ -228,29 +239,29 @@ def test_iter_and_zip(self): 
self.assertEqual( for_instance.nrows, 3 * 2 * 2, - msg="Zipped stuff is nested with the individually nested fields" + msg="Zipped stuff is nested with the individually nested fields", ) self.assertEqual( for_instance.ncols, 1 + 4, - msg="Dataframe should only hold output and _looped_ input" + msg="Dataframe should only hold output and _looped_ input", ) # We don't actually care if the order of nesting changes, but make sure the # iters are getting nested and zipped stay together self.assertTupleEqual( self._get_column(out, output_as_dataframe)[0], ((42, 13, 100, -1, "both"),), - msg="All start" + msg="All start", ) self.assertTupleEqual( self._get_column(out, output_as_dataframe)[1], ((42, 13, 101, -2, "both"),), - msg="Bump zipped together" + msg="Bump zipped together", ) self.assertTupleEqual( self._get_column(out, output_as_dataframe)[2], ((42, 14, 100, -1, "both"),), - msg="Back to start of zipped, bump _one_ iter" + msg="Back to start of zipped, bump _one_ iter", ) def test_dynamic_length(self): @@ -258,25 +269,27 @@ def test_dynamic_length(self): with self.subTest(f"output_as_dataframe {output_as_dataframe}"): for_instance = for_node( FiveTogether, - iter_on=("a", "b",), + iter_on=( + "a", + "b", + ), a=[42, 43, 44], b=[13, 14], - zip_on=("c", "d",), + zip_on=( + "c", + "d", + ), c=[100, 101], d=[-1, -2, -3], output_as_dataframe=output_as_dataframe, ) for_instance() - self.assertEqual( - for_instance.nrows, - 3 * 2 * 2, - msg="Sanity check" - ) + self.assertEqual(for_instance.nrows, 3 * 2 * 2, msg="Sanity check") for_instance(a=[0], b=[1], c=[2]) self.assertEqual( for_instance.nrows, 1, - msg="Should be able to re-run with different input lengths" + msg="Should be able to re-run with different input lengths", ) def test_column_mapping(self): @@ -288,7 +301,13 @@ def FiveApart( d: int = 3, e: str = "foobar", ): - return a, b, c, d, e, + return ( + a, + b, + c, + d, + e, + ) for output_as_dataframe in [True, False]: with self.subTest(f"output_as_dataframe {output_as_dataframe}"): @@ -306,7 +325,7 @@ def FiveApart( "a": "out_a", "b": "out_b", "c": "out_c", - "d": "out_d" + "d": "out_d", }, output_as_dataframe=output_as_dataframe, ) @@ -315,12 +334,16 @@ def FiveApart( for_instance.ncols, 4 + 5, # loop inputs + outputs msg="When all conflicting names are remapped, we should have no " - "trouble" + "trouble", ) - with self.subTest("Insufficient map"), self.assertRaises( + with ( + self.subTest("Insufficient map"), + self.assertRaises( UnmappedConflictError, - msg="Leaving conflicting channels unmapped should raise an error"): + msg="Leaving conflicting channels unmapped should raise an error", + ), + ): for_node( FiveApart, iter_on=("a", "b"), @@ -334,14 +357,17 @@ def FiveApart( # "a": "out_a", "b": "out_b", "c": "out_c", - "d": "out_d" + "d": "out_d", }, output_as_dataframe=output_as_dataframe, ) - with self.subTest("Excessive map"), self.assertRaises( - MapsToNonexistentOutputError, - msg="Trying to map something that isn't there should raise an error" + with ( + self.subTest("Excessive map"), + self.assertRaises( + MapsToNonexistentOutputError, + msg="Trying to map something that isn't there should raise an error", + ), ): for_node( FiveApart, @@ -357,7 +383,7 @@ def FiveApart( "b": "out_b", "c": "out_c", "d": "out_d", - "not_a_key_on_the_body_node_outputs": "anything" + "not_a_key_on_the_body_node_outputs": "anything", }, output_as_dataframe=output_as_dataframe, ) @@ -376,7 +402,7 @@ def test_body_node_executor(self): n_procs = 4 with ThreadPoolExecutor(max_workers=n_procs) as exe: 
for_parallel.body_node_executor = exe - for_parallel(t=n_procs*[t_sleep]) + for_parallel(t=n_procs * [t_sleep]) dt = perf_counter() - t_start grace = 1.25 self.assertLess( @@ -384,7 +410,7 @@ def test_body_node_executor(self): grace * t_sleep, msg=f"Parallelization over children should result in faster " f"completion. Expected limit {grace} x {t_sleep} = " - f"{grace * t_sleep} -- got {dt}" + f"{grace * t_sleep} -- got {dt}", ) reloaded = pickle.loads(pickle.dumps(for_parallel)) @@ -392,7 +418,7 @@ def test_body_node_executor(self): reloaded.body_node_executor, msg="Just like regular nodes, until executors can be delayed creators " "instead of actual executor nodes, we need to purge executors from " - "nodes on serialization or the thread lock/queue objects hit us" + "nodes on serialization or the thread lock/queue objects hit us", ) def test_with_connections_dataframe(self): @@ -400,12 +426,13 @@ def test_with_connections_dataframe(self): @as_macro_node def LoopInside(self, x: list, y: int): - self.to_list = inputs_to_list( - length_y, y, y, y - ) + self.to_list = inputs_to_list(length_y, y, y, y) self.loop = for_node( Add, - iter_on=("obj", "other",), + iter_on=( + "obj", + "other", + ), obj=x, other=self.to_list, output_as_dataframe=True, @@ -418,9 +445,7 @@ def LoopInside(self, x: list, y: int): self.assertIsInstance(df, DataFrame) self.assertEqual(length_y * len(x), len(df)) self.assertEqual( - x[0] + y, - df["add"][0], - msg="Just make sure the loop is actually running" + x[0] + y, df["add"][0], msg="Just make sure the loop is actually running" ) x, y = [2, 3], 4 df = li(x, y).loop @@ -428,7 +453,7 @@ def LoopInside(self, x: list, y: int): self.assertEqual( x[-1] + y, df["add"][len(df) - 1], - msg="And make sure that we can vary the length still" + msg="And make sure that we can vary the length still", ) def test_with_connections_list(self): @@ -436,12 +461,13 @@ def test_with_connections_list(self): @as_macro_node def LoopInside(self, x: list, y: int): - self.to_list = inputs_to_list( - length_y, y, y, y - ) + self.to_list = inputs_to_list(length_y, y, y, y) self.loop = for_node( Add, - iter_on=("obj", "other",), + iter_on=( + "obj", + "other", + ), obj=x, other=self.to_list, output_as_dataframe=False, @@ -451,22 +477,20 @@ def LoopInside(self, x: list, y: int): x, y = [1], 2 li = LoopInside([1], 2) - l = li().out + li_out = li().out print(li) - self.assertIsInstance(l, list) - self.assertEqual(length_y * len(x), len(l)) + self.assertIsInstance(li_out, list) + self.assertEqual(length_y * len(x), len(li_out)) self.assertEqual( - x[0] + y, - l[0], - msg="Just make sure the loop is actually running" + x[0] + y, li_out[0], msg="Just make sure the loop is actually running" ) x, y = [2, 3], 4 - l = li(x, y).out - self.assertEqual(length_y * len(x), len(l)) + li_out = li(x, y).out + self.assertEqual(length_y * len(x), len(li_out)) self.assertEqual( x[-1] + y, - l[len(l) - 1], - msg="And make sure that we can vary the length still" + li_out[len(li_out) - 1], + msg="And make sure that we can vary the length still", ) def test_node_access_points(self): @@ -478,7 +502,13 @@ def test_node_access_points(self): self.assertEqual(2 * 2, len(df)) self.assertTupleEqual( df["together"][1][0], - (1, 2, 3, 6, "instance",) + ( + 1, + 2, + 3, + 6, + "instance", + ), ) with self.subTest("Zip"): @@ -487,7 +517,13 @@ def test_node_access_points(self): self.assertEqual(2, len(df)) self.assertTupleEqual( df["together"][1][0], - (1, 2, 4, 6, "instance",) + ( + 1, + 2, + 4, + 6, + "instance", + ), ) def 
test_shortcut(self): @@ -506,9 +542,9 @@ def test_shortcut(self): out = loop2() self.assertListEqual( out.mul, - [(1+1)*1, (1+2)*2], + [(1 + 1) * 1, (1 + 2) * 2], msg="We should be able to call for_node right from node classes to bypass " - "needing to provide the `body_node_class` argument" + "needing to provide the `body_node_class` argument", ) def test_macro_body(self): @@ -548,10 +584,13 @@ def test_repeated_creation(self): self.assertTrue(n3._output_as_dataframe) def test_executor_deserialization(self): - for title, executor, expected in [ ("Instance", ThreadPoolExecutor(), None), - ("Instructions", (ThreadPoolExecutor, (), {}), (ThreadPoolExecutor, (), {})) + ( + "Instructions", + (ThreadPoolExecutor, (), {}), + (ThreadPoolExecutor, (), {}), + ), ]: with self.subTest(title): n = for_node( @@ -569,7 +608,7 @@ def test_executor_deserialization(self): expected, msg="Executor instances should get removed on " "(de)serialization, but instructions on how to build one " - "should not." + "should not.", ) finally: n.delete_storage() diff --git a/tests/unit/nodes/test_function.py b/tests/unit/nodes/test_function.py index bda09d48..34e21635 100644 --- a/tests/unit/nodes/test_function.py +++ b/tests/unit/nodes/test_function.py @@ -61,7 +61,7 @@ def test_instantiation(self): node.outputs.y.value, 11, msg="Expected the run to update the output -- did the test function" - "change or something?" + "change or something?", ) node = function_node(no_default, 1, y=2, output_labels="output") @@ -69,13 +69,13 @@ def test_instantiation(self): self.assertEqual( no_default(1, 2), node.outputs.output.value, - msg="Nodes should allow input initialization by arg _and_ kwarg" + msg="Nodes should allow input initialization by arg _and_ kwarg", ) node(2, y=3) self.assertEqual( no_default(2, 3), node.outputs.output.value, - msg="Nodes should allow input update on call by arg and kwarg" + msg="Nodes should allow input update on call by arg and kwarg", ) with self.assertRaises(ValueError): @@ -88,7 +88,7 @@ def test_instantiation(self): self.assertIs( node2.inputs.x.connections[0], node.outputs.y, - msg="Should be able to make a connection at initialization" + msg="Should be able to make a connection at initialization", ) node >> node2 node.run() @@ -112,7 +112,7 @@ def test_defaults(self): self.assertFalse( without_defaults.ready, msg="I guess we should test for behaviour and not implementation... Without" - "defaults, the node should not be ready!" 
+ "defaults, the node should not be ready!", ) def test_label_choices(self): @@ -132,23 +132,23 @@ def test_label_choices(self): ) self.assertListEqual(n.outputs.labels, ["its_a_tuple"]) - with self.subTest("Fail on multiple return values"), self.assertRaises(ValueError): + with ( + self.subTest("Fail on multiple return values"), + self.assertRaises(ValueError), + ): # Can't automatically parse output labels from a function with multiple # return expressions function_node(multiple_branches) with self.subTest("Override output label scraping"): with self.assertRaises( - ValueError, - msg="Multiple return branches can't be parsed" + ValueError, msg="Multiple return branches can't be parsed" ): switch = function_node(multiple_branches, output_labels="bool") self.assertListEqual(switch.outputs.labels, ["bool"]) switch = function_node( - multiple_branches, - output_labels="bool", - validate_output_labels=False + multiple_branches, output_labels="bool", validate_output_labels=False ) self.assertListEqual(switch.outputs.labels, ["bool"]) @@ -170,7 +170,7 @@ def bilinear(x, y): bilinear(2, 3).run(), 2 * 3, msg="Children of `Function` should have their `node_function` exposed for " - "use at the class level" + "use at the class level", ) def test_statuses(self): @@ -184,7 +184,7 @@ def test_statuses(self): with self.assertRaises( TypeError, msg="We expect the int+str type error because there were no type hints " - "guarding this function from running with bad data" + "guarding this function from running with bad data", ): n.run() self.assertFalse(n.ready) @@ -208,12 +208,12 @@ def test_call(self): self.assertEqual( node.inputs.x.value, 1, - msg="__call__ should accept args to update input" + msg="__call__ should accept args to update input", ) self.assertEqual( node.inputs.y.value, 2, - msg="__call__ should accept kwargs to update input" + msg="__call__ should accept kwargs to update input", ) self.assertEqual( node.outputs.output.value, 1 + 2 + 1, msg="__call__ should run things" @@ -223,7 +223,7 @@ def test_call(self): self.assertEqual( no_default(3, 2), node.outputs.output.value, - msg="__call__ should allow updating only _some_ input before running" + msg="__call__ should allow updating only _some_ input before running", ) with self.assertRaises(ValueError, msg="Check that bad kwargs raise an error"): @@ -239,14 +239,14 @@ def test_return_value(self): return_on_explicit_run, plus_one(2), msg="On explicit run, the most recent input data should be used and " - "the result should be returned" + "the result should be returned", ) return_on_call = node(1) self.assertEqual( return_on_call, plus_one(1), - msg="Run output should be returned on call" + msg="Run output should be returned on call", # This is a duplicate test, since __call__ just invokes run, but it is # such a core promise that let's just double-check it ) @@ -284,16 +284,16 @@ def plus_one_hinted(x: int = 0) -> int: self.assertFalse( node.connected, msg="The x-input connection should have been copied, but should be " - "removed when the copy fails." 
+ "removed when the copy fails.", ) with self.assertRaises( ConnectionCopyError, msg="An unhinted channel is not a valid connection for a hinted " - "channel, and should raise and exception" + "channel, and should raise and exception", ): hinted_node._copy_connections(to_copy) - hinted_node.disconnect()# Make sure you've got a clean slate + hinted_node.disconnect() # Make sure you've got a clean slate node.disconnect() # Make sure you've got a clean slate with self.subTest("Ensure that failures can be continued past"): @@ -306,13 +306,13 @@ def plus_one_hinted(x: int = 0) -> int: hinted_node.inputs.connected, msg="Without hard failure the copy should be allowed to proceed, but " "we don't actually expect any connections to get copied since the " - "only one available had type hint problems" + "only one available had type hint problems", ) self.assertTrue( hinted_node.outputs.connected, msg="Without hard failure the copy should be allowed to proceed, so " "the output should connect fine since feeding hinted to un-hinted " - "is a-ok" + "is a-ok", ) def test_copy_values(self): @@ -339,29 +339,23 @@ def all_floats(x=1.1, y=1.1, z=1.1, omega=NOT_DATA, extra_there=None) -> float: ref._copy_values(floats) self.assertEqual( - ref.inputs.x.value, - 1.1, - msg="Untyped channels should copy freely" + ref.inputs.x.value, 1.1, msg="Untyped channels should copy freely" ) self.assertEqual( ref.inputs.y.value, 0, - msg="Typed channels should ignore values where the type check fails" + msg="Typed channels should ignore values where the type check fails", ) self.assertEqual( ref.inputs.z.value, 1.1, - msg="Typed channels should copy values that conform to their hint" + msg="Typed channels should copy values that conform to their hint", ) self.assertEqual( - ref.inputs.omega.value, - None, - msg="NOT_DATA should be ignored when copying" + ref.inputs.omega.value, None, msg="NOT_DATA should be ignored when copying" ) self.assertEqual( - ref.outputs.out.value, - 42.1, - msg="Output data should also get copied" + ref.outputs.out.value, 42.1, msg="Output data should also get copied" ) # Note also that these nodes each have extra channels the other doesn't that # are simply ignored @@ -376,15 +370,14 @@ def extra_channel(x=1, y=1, z=1, not_present=42): ref.inputs.x = 0 # Revert the value with self.assertRaises( - ValueCopyError, - msg="Type hint should prevent update when we fail hard" + ValueCopyError, msg="Type hint should prevent update when we fail hard" ): ref._copy_values(floats, fail_hard=True) ref._copy_values(extra) # No problem with self.assertRaises( ValueCopyError, - msg="Missing a channel that holds data is also grounds for failure" + msg="Missing a channel that holds data is also grounds for failure", ): ref._copy_values(extra, fail_hard=True) @@ -395,38 +388,34 @@ def test_easy_output_connection(self): n2.inputs.x = n1 self.assertIn( - n1.outputs.y, n2.inputs.x.connections, + n1.outputs.y, + n2.inputs.x.connections, msg="Single-output functions should be able to make connections between " - "their output and another node's input by passing themselves" + "their output and another node's input by passing themselves", ) n1 >> n2 n1.run() self.assertEqual( - n2.outputs.y.value, 3, + n2.outputs.y.value, + 3, msg="Single-output function connections should pass data just like usual; " - "in this case default->plus_one->plus_one = 1 + 1 +1 = 3" + "in this case default->plus_one->plus_one = 1 + 1 +1 = 3", ) at_instantiation = function_node(plus_one, x=n1) self.assertIn( - n1.outputs.y, 
at_instantiation.inputs.x.connections, + n1.outputs.y, + at_instantiation.inputs.x.connections, msg="The parsing of Single-output functions' output as a connection should " - "also work from assignment at instantiation" + "also work from assignment at instantiation", ) def test_nested_declaration(self): # It's really just a silly case of running without a parent, where you don't # store references to all the nodes declared node = function_node( - plus_one, - x=function_node( - plus_one, - x=function_node( - plus_one, - x=2 - ) - ) + plus_one, x=function_node(plus_one, x=function_node(plus_one, x=2)) ) self.assertEqual(2 + 1 + 1 + 1, node.pull()) @@ -446,7 +435,7 @@ def returns_foo() -> Foo: self.assertEqual( single_output.connected, False, - msg="Should return the _node_ attribute, not acting on the output channel" + msg="Should return the _node_ attribute, not acting on the output channel", ) injection = single_output[0] # Should pass cleanly, even though it tries to run @@ -455,18 +444,18 @@ def returns_foo() -> Foo: self.assertEqual( single_output.some_attribute.value, # The call runs the dynamic node "exists", - msg="Should fall back to acting on the output channel and creating a node" + msg="Should fall back to acting on the output channel and creating a node", ) self.assertEqual( single_output.connected, True, - msg="Should now be connected to the dynamically created nodes" + msg="Should now be connected to the dynamically created nodes", ) with self.assertRaises( AttributeError, - msg="Aggressive running hits the problem that no such attribute exists" + msg="Aggressive running hits the problem that no such attribute exists", ): injected = single_output.doesnt_exists_anywhere # noqa: F841 # The injected node fails at runtime and generates a recovery file @@ -479,20 +468,18 @@ def returns_foo() -> Foo: p.rmdir() self.assertEqual( - injection(), - True, - msg="Should be able to query injection later" + injection(), True, msg="Should be able to query injection later" ) self.assertEqual( single_output["some other key"].value, False, - msg="Should fall back to looking on the single value" + msg="Should fall back to looking on the single value", ) with self.assertRaises( AttributeError, - msg="Attribute injection should not work for private attributes" + msg="Attribute injection should not work for private attributes", ): single_output._some_nonexistant_private_var # noqa: B018 @@ -507,7 +494,7 @@ def NoReturn(x): {"None": type(None)}, NoReturn.preview_outputs(), msg="Functions without a return value should be permissible, although it " - "is not interesting" + "is not interesting", ) # Honestly, functions with no return should probably be made illegal to # encourage functional setups... 
@@ -518,43 +505,40 @@ def test_pickle(self): reloaded = pickle.loads(pickle.dumps(n)) self.assertListEqual(n.outputs.labels, reloaded.outputs.labels) self.assertDictEqual( - n.outputs.to_value_dict(), - reloaded.outputs.to_value_dict() + n.outputs.to_value_dict(), reloaded.outputs.to_value_dict() ) def test_decoration(self): with self.subTest("@as_function_node(*output_labels, ...)"): WithDecoratorSignature = as_function_node("z")(plus_one) self.assertTrue( - issubclass(WithDecoratorSignature, Function), - msg="Sanity check" + issubclass(WithDecoratorSignature, Function), msg="Sanity check" ) self.assertListEqual( ["z"], list(WithDecoratorSignature.preview_outputs().keys()), - msg="Decorator should capture new output label" + msg="Decorator should capture new output label", ) with self.subTest("@as_function_node"): WithoutDecoratorSignature = as_function_node(plus_one) self.assertTrue( - issubclass(WithoutDecoratorSignature, Function), - msg="Sanity check" + issubclass(WithoutDecoratorSignature, Function), msg="Sanity check" ) self.assertListEqual( ["y"], # "Default" copied here from the function definition return list(WithoutDecoratorSignature.preview_outputs().keys()), - msg="Decorator should capture new output label" + msg="Decorator should capture new output label", ) with self.assertRaises( MultipleDispatchError, msg="This shouldn't be accessible from a regular decorator usage pattern, " "but make sure that mixing-and-matching argument-free calls and calls " - "directly providing the wrapped node fail cleanly" + "directly providing the wrapped node fail cleanly", ): as_function_node(plus_one, "z") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/nodes/test_macro.py b/tests/unit/nodes/test_macro.py index c22adbae..92e7f59c 100644 --- a/tests/unit/nodes/test_macro.py +++ b/tests/unit/nodes/test_macro.py @@ -40,23 +40,22 @@ def SomeNode(x): class TestMacro(unittest.TestCase): - def test_io_independence(self): m = macro_node(add_three_macro, output_labels="three__result") self.assertIsNot( m.inputs.one__x, m.one.inputs.x, - msg="Expect input to be by value, not by reference" + msg="Expect input to be by value, not by reference", ) self.assertIsNot( m.outputs.three__result, m.three.outputs.result, - msg="Expect output to be by value, not by reference" + msg="Expect output to be by value, not by reference", ) self.assertFalse( m.connected, msg="Macro should talk to its children by value links _not_ graph " - "connections" + "connections", ) def test_value_links(self): @@ -64,12 +63,12 @@ def test_value_links(self): self.assertIs( m.one.inputs.x, m.inputs.one__x.value_receiver, - msg="Sanity check that value link exists" + msg="Sanity check that value link exists", ) self.assertIs( m.outputs.three__result, m.three.outputs.result.value_receiver, - msg="Sanity check that value link exists" + msg="Sanity check that value link exists", ) self.assertNotEqual( 42, m.one.inputs.x.value, msg="Sanity check that we start from expected" @@ -77,7 +76,7 @@ def test_value_links(self): self.assertNotEqual( 42, m.three.outputs.result.value, - msg="Sanity check that we start from expected" + msg="Sanity check that we start from expected", ) m.inputs.one__x.value = 0 self.assertEqual( @@ -115,12 +114,12 @@ def only_starting(self, one__x): self.assertEqual( m_auto(one__x=x).three__result, expected, - "DAG macros should run fine without user specification of execution." 
+ "DAG macros should run fine without user specification of execution.", ) self.assertEqual( m_user(one__x=x).three__result, expected, - "Macros should run fine if the user nicely specifies the exeuction graph." + "Macros should run fine if the user nicely specifies the exeuction graph.", ) with self.subTest("Partially specified execution should fail"): @@ -137,7 +136,7 @@ def test_default_label(self): self.assertEqual( m.label, add_three_macro.__name__, - msg="Label should be automatically generated" + msg="Label should be automatically generated", ) label = "custom_name" m2 = macro_node(add_three_macro, label=label, output_labels="three__result") @@ -150,7 +149,7 @@ def test_creation_from_decorator(self): m.outputs.three__result.value, NOT_DATA, msg="Output should be accessible with the usual naming convention, but we " - "have not run yet so there shouldn't be any data" + "have not run yet so there shouldn't be any data", ) input_x = 1 @@ -161,12 +160,12 @@ def test_creation_from_decorator(self): self.assertEqual( out.three__result, expected_value, - msg="Macros should return the output, just like other nodes" + msg="Macros should return the output, just like other nodes", ) self.assertEqual( m.outputs.three__result.value, expected_value, - msg="Macros should get output updated, just like other nodes" + msg="Macros should get output updated, just like other nodes", ) def test_creation_from_subclass(self): @@ -184,21 +183,19 @@ def graph_creator(self, one__x): self.assertEqual( m.outputs.three__result.value, add_one(add_one(add_one(x))), - msg="Subclasses should be able to simply override the graph_creator arg" + msg="Subclasses should be able to simply override the graph_creator arg", ) def test_nesting(self): def nested_macro(self, a__x): self.a = function_node(add_one, a__x) self.b = macro_node( - add_three_macro, - one__x=self.a, - output_labels="three__result" + add_three_macro, one__x=self.a, output_labels="three__result" ) self.c = macro_node( add_three_macro, one__x=self.b.outputs.three__result, - output_labels="three__result" + output_labels="three__result", ) self.d = function_node( add_one, @@ -226,32 +223,27 @@ def test_with_executor(self): self.assertIs( NOT_DATA, macro.outputs.three__result.value, - msg="Sanity check that test is in right starting condition" + msg="Sanity check that test is in right starting condition", ) result = macro.run(one__x=0) self.assertIsInstance( - result, - Future, - msg="Should be running as a parallel process" + result, Future, msg="Should be running as a parallel process" ) self.assertIs( NOT_DATA, downstream.outputs.result.value, msg="Downstream events should not yet have triggered either, we should wait" - "for the callback when the result is ready" + "for the callback when the result is ready", ) returned_nodes = result.result(timeout=120) # Wait for the process to finish sleep(1) - self.assertFalse( - macro.running, - msg="Macro should be done running" - ) + self.assertFalse(macro.running, msg="Macro should be done running") self.assertIsNot( original_one, returned_nodes.one, - msg="Executing in a parallel process should be returning new instances" + msg="Executing in a parallel process should be returning new instances", ) # self.assertIs( # returned_nodes.one, @@ -261,26 +253,26 @@ def test_with_executor(self): self.assertIs( macro, macro.one.parent, - msg="Returned nodes should get the macro as their parent" + msg="Returned nodes should get the macro as their parent", # Once upon a time there was some evidence that this test was failing 
# stochastically, but I just ran the whole test suite 6 times and this test # 8 times and it always passed fine, so maybe the issue is resolved... ) self.assertIsNone( original_one.parent, - msg="Original nodes should be orphaned" + msg="Original nodes should be orphaned", # Note: At time of writing, this is accomplished in Node.__getstate__, # which feels a bit dangerous... ) self.assertEqual( 0 + 3, macro.outputs.three__result.value, - msg="And of course we expect the calculation to actually run" + msg="And of course we expect the calculation to actually run", ) self.assertIs( downstream.inputs.x.connections[0], macro.outputs.three__result, - msg="The macro output should still be connected to downstream" + msg="The macro output should still be connected to downstream", ) sleep(0.2)  # Give a moment for the ran signal to emit and downstream to run # I'm a bit surprised this sleep is necessary @@ -288,14 +280,16 @@ def test_with_executor(self): 0 + 3 + 1, downstream.outputs.result.value, msg="The finishing callback should also fire off the ran signal triggering" - "downstream execution" + " downstream execution", ) macro.executor_shutdown() def test_pulling_from_inside_a_macro(self): upstream = function_node(add_one, x=2) - macro = macro_node(add_three_macro, one__x=upstream, output_labels="three__result") + macro = macro_node( + add_three_macro, one__x=upstream, output_labels="three__result" + ) macro.inputs.one__x = 0  # Set value # Now macro.one.inputs.x has both value and a connection @@ -304,14 +298,14 @@ macro.two.pull(run_parent_trees_too=False), msg="Without running parent trees, the pulling should only run upstream " "nodes _inside_ the scope of the macro, relying on the explicit input" - "value" + " value", ) self.assertEqual( (2 + 1) + 1 + 1, macro.two.pull(run_parent_trees_too=True), msg="Running with parent trees, the pulling should also run the parents " - "data dependencies first" + "data dependencies first", ) def test_recovery_after_failed_pull(self): @@ -320,6 +314,7 @@ def grab_x_and_run(node): return node.inputs.x.connections + node.signals.input.run.connections with self.subTest("When the local scope has cyclic data flow"): + def cyclic_macro(macro): macro.one = function_node(add_one) macro.two = function_node(add_one, x=macro.one) @@ -338,22 +333,24 @@ def grab_connections(macro): initial_connections = grab_connections(m) with self.assertRaises( - CircularDataFlowError, - msg="Pull should only work for DAG workflows" + CircularDataFlowError, msg="Pull should only work for DAG workflows" ): m.two.pull() self.assertListEqual( initial_labels, list(m.children.keys()), msg="Labels should be restored after failing to pull because of " - "acyclicity" + "acyclicity", ) self.assertTrue( all( - c is ic for (c, ic) in zip(grab_connections(m), initial_connections, strict=False) + c is ic + for (c, ic) in zip( + grab_connections(m), initial_connections, strict=False + ) ), msg="Connections should be restored after failing to pull because of " - "cyclic data flow" + "cyclic data flow", ) with self.subTest("When the parent scope has cyclic data flow"): @@ -366,7 +363,7 @@ def grab_connections(macro): self.assertEqual( 0 + 1 + 1 + (1 + 1 + 1), m.three.pull(run_parent_trees_too=True), - msg="Sanity check, without cyclic data flows pulling here should be ok" + msg="Sanity check, without cyclic data flows pulling here should be ok", ) n1.inputs.x = n2 @@ -374,32 +371,33 @@ initial_connections = grab_x_and_run(n1) + 
grab_x_and_run(n2) with self.assertRaises( CircularDataFlowError, - msg="Once the outer scope has circular data flows, pulling should fail" + msg="Once the outer scope has circular data flows, pulling should fail", ): m.three.pull(run_parent_trees_too=True) self.assertTrue( all( c is ic for (c, ic) in zip( - grab_x_and_run(n1) + grab_x_and_run(n2), initial_connections, strict=False + grab_x_and_run(n1) + grab_x_and_run(n2), + initial_connections, + strict=False, ) ), msg="Connections should be restored after failing to pull because of " - "cyclic data flow in the outer scope" + "cyclic data flow in the outer scope", ) self.assertEqual( - "n1", - n1.label, - msg="Labels should get restored in the outer scope" + "n1", n1.label, msg="Labels should get restored in the outer scope" ) self.assertEqual( "one", m.one.label, msg="Labels should not have even gotten perturbed to start with in the" - "inner scope" + " inner scope", ) with self.subTest("When a node breaks upstream"): + def fail_at_zero(x): y = 1 / x return y @@ -411,23 +409,23 @@ def fail_at_zero(x): n_not_used >> n2  # Just here to make sure it gets restored with self.assertRaises( - ZeroDivisionError, - msg="The underlying error should get raised" + ZeroDivisionError, msg="The underlying error should get raised" ): n2.pull() self.assertEqual( "n1", n2.label, - msg="Original labels should get restored on upstream failure" + msg="Original labels should get restored on upstream failure", ) self.assertIs( n_not_used, n2.signals.input.run.connections[0].owner, - msg="Original connections should get restored on upstream failure" + msg="Original connections should get restored on upstream failure", ) def test_efficient_signature_interface(self): with self.subTest("Forked input"): + @as_macro_node("output") def MutlipleUseInput(self, x): self.n1 = self.create.standard.UserInput(x) @@ -439,11 +437,11 @@ def MutlipleUseInput(self, x): 2 + 1, len(m), msg="Signature input that is forked to multiple children should result " "in the automatic creation of a new node to manage the forking." - + "in the automatic creation of a new node to manage the forking.", ) with self.subTest("Single destination input"): + @as_macro_node("output") def SingleUseInput(self, x): self.n = self.create.standard.UserInput(x) @@ -454,10 +452,11 @@ def SingleUseInput(self, x): 1, len(m), msg=f"Signature input with only one destination should not create an " - f"interface node. Found nodes {m.child_labels}" + f"interface node. Found nodes {m.child_labels}", ) with self.subTest("Mixed input"): + @as_macro_node("output") def MixedUseInput(self, x, y): self.n1 = self.create.standard.UserInput(x) @@ -470,10 +469,11 @@ def MixedUseInput(self, x, y): 3 + 1, len(m), msg=f"Mixing forked and single-use input should not cause problems. 
" - f"Expected four children but found {m.child_labels}" + f"Expected four children but found {m.child_labels}", ) with self.subTest("Pass through"): + @as_macro_node("output") def PassThrough(self, x): return x @@ -486,27 +486,22 @@ def test_storage_for_modified_macros(self): with self.subTest(backend): try: macro = demo_nodes.AddThree(label="m", x=0) - macro.replace_child( - macro.two, - demo_nodes.AddPlusOne() - ) + macro.replace_child(macro.two, demo_nodes.AddPlusOne()) modified_result = macro() if isinstance(backend, PickleStorage): macro.save(backend) - reloaded = demo_nodes.AddThree( - label="m", autoload=backend - ) + reloaded = demo_nodes.AddThree(label="m", autoload=backend) self.assertDictEqual( modified_result, reloaded.outputs.to_value_dict(), - msg="Updated IO should have been (de)serialized" + msg="Updated IO should have been (de)serialized", ) self.assertSetEqual( set(macro.children.keys()), set(reloaded.children.keys()), - msg="All nodes should have been (de)serialized." + msg="All nodes should have been (de)serialized.", ) self.assertEqual( demo_nodes.AddThree.__name__, @@ -517,22 +512,21 @@ def test_storage_for_modified_macros(self): f"not any sort of technical error -- what other class name " f"would we load? -- but is a deeper problem with saving " f"modified objects that we need ot figure out some better " - f"solution for later." + f"solution for later.", ) rerun = reloaded() self.assertIsInstance( reloaded.two, demo_nodes.AddPlusOne, - msg="pickle instantiates the macro node class, but " "but then uses its serialized state, so we retain " - "the replaced node." + "the replaced node.", ) self.assertDictEqual( modified_result, rerun, - msg="Rerunning re-executes the _replaced_ functionality" + msg="Rerunning re-executes the _replaced_ functionality", ) else: raise ValueError( @@ -552,13 +546,14 @@ def OutputScrapedFromFilteredReturn(macro): self.assertListEqual( ["foo"], list(OutputScrapedFromFilteredReturn.preview_outputs().keys()), - msg="The first, self-like argument, should get stripped from output labels" + msg="The first, self-like argument, should get stripped from output labels", ) with self.assertRaises( ValueError, - msg="Return values with extra dots are not permissible as scraped labels" + msg="Return values with extra dots are not permissible as scraped labels", ): + @as_macro_node def ReturnHasDot(macro): macro.foo = macro.create.standard.UserInput() @@ -571,16 +566,16 @@ def test_pickle(self): self.assertTupleEqual( m.child_labels, reloaded_m.child_labels, - msg="Spot check values are getting reloaded correctly" + msg="Spot check values are getting reloaded correctly", ) self.assertDictEqual( m.outputs.to_value_dict(), reloaded_m.outputs.to_value_dict(), - msg="Spot check values are getting reloaded correctly" + msg="Spot check values are getting reloaded correctly", ) self.assertTrue( reloaded_m.two.connected, - msg="The macro should reload with all its child connections" + msg="The macro should reload with all its child connections", ) self.assertTrue(m.two.connected, msg="Sanity check") @@ -589,13 +584,13 @@ def test_pickle(self): reloaded_two.connected, msg="Children are expected to be de-parenting on serialization, so that if " "we ship them off to another process, they don't drag their whole " - "graph with them" + "graph with them", ) self.assertEqual( m.two.outputs.to_value_dict(), reloaded_two.outputs.to_value_dict(), msg="The remainder of the child node state should be recovering just " - "fine on (de)serialization, this is a spot-check" + 
"fine on (de)serialization, this is a spot-check", ) def test_autoload(self): @@ -605,6 +600,7 @@ def test_autoload(self): existing_node.save("pickle") try: + @as_macro_node def AutoloadsChildren(self, x): self.some_child = SomeNode(x, autoload="pickle") @@ -616,16 +612,13 @@ def AutoloadsChildren(self, x): msg="Autoloading macro children can result in a child node coming with " "pre-loaded data if the child's label at instantiation results in a " "match with some already-saved node (if the load is compatible). This " - "is almost certainly undesirable" + "is almost certainly undesirable", ) @as_macro_node def AutofailsChildren(self, x): self.some_child = function_node( - add_one, - x, - label=SomeNode.__name__, - autoload="pickle" + add_one, x, label=SomeNode.__name__, autoload="pickle" ) return self.some_child @@ -633,7 +626,7 @@ def AutofailsChildren(self, x): TypeError, msg="When the macro auto-loads a child but the loaded type is not " "compatible with the child type, we will even get an error at macro " - "instantiation time! Autoloading macro children is really not wise." + "instantiation time! Autoloading macro children is really not wise.", ): AutofailsChildren() @@ -648,11 +641,11 @@ def DoesntAutoloadChildren(self, x): msg="Despite having the same label as a saved node at instantiation time, " "without autoloading children, our macro safely gets a fresh instance. " "Since this is clearly preferable, here we leave autoload to take its " - "default value (which for macros should thus not autoload.)" + "default value (which for macros should thus not autoload.)", ) finally: existing_node.delete_storage("pickle") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/nodes/test_transform.py b/tests/unit/nodes/test_transform.py index 8381a24e..a26ceb1d 100644 --- a/tests/unit/nodes/test_transform.py +++ b/tests/unit/nodes/test_transform.py @@ -22,6 +22,7 @@ class MyData: stuff: bool = False + @as_function_node def Downstream(x: MyData.dataclass): x.stuff = True @@ -31,16 +32,12 @@ def Downstream(x: MyData.dataclass): class TestTransformer(unittest.TestCase): def test_pickle(self): n = inputs_to_list(3, "a", "b", "c", autorun=True) - self.assertListEqual( - ["a", "b", "c"], - n.outputs.list.value, - msg="Sanity check" - ) + self.assertListEqual(["a", "b", "c"], n.outputs.list.value, msg="Sanity check") reloaded = pickle.loads(pickle.dumps(n)) self.assertListEqual( n.outputs.list.value, reloaded.outputs.list.value, - msg="Transformer nodes should be (un)pickleable" + msg="Transformer nodes should be (un)pickleable", ) self.assertIsInstance(reloaded, Transformer) @@ -49,9 +46,9 @@ def test_inputs_to_list(self): self.assertListEqual(["a", "b", "c"], n.outputs.list.value) def test_list_to_outputs(self): - l = ["a", "b", "c", "d", "e"] - n = list_to_outputs(5, l, autorun=True) - self.assertEqual(l, n.outputs.to_list()) + lst = ["a", "b", "c", "d", "e"] + n = list_to_outputs(len(lst), lst, autorun=True) + self.assertEqual(lst, n.outputs.to_list()) def test_inputs_to_dict(self): with self.subTest("List specification"): @@ -60,7 +57,7 @@ def test_inputs_to_dict(self): self.assertDictEqual( d, n.outputs.dict.value, - msg="Verify structure and ability to pass kwargs" + msg="Verify structure and ability to pass kwargs", ) with self.subTest("Dict specification"): @@ -72,27 +69,25 @@ def test_inputs_to_dict(self): self.assertIs( n.inputs[list(d.keys())[0]].type_hint, hint, - msg="Spot check hint recognition" + msg="Spot check hint recognition", ) 
self.assertDictEqual( {k: default for k in d}, n.outputs.dict.value, - msg="Verify structure and ability to pass defaults" + msg="Verify structure and ability to pass defaults", ) with self.subTest("Explicit suffix"): suffix = "MyName" n = inputs_to_dict(["c1", "c2"], class_name_suffix="MyName") - self.assertTrue( - n.__class__.__name__.endswith(suffix) - ) + self.assertTrue(n.__class__.__name__.endswith(suffix)) with self.subTest("Only hashable"): unhashable_spec = {"c1": (list, ["an item"])} with self.assertRaises( ValueError, msg="List instances are not hashable, we should not be able to auto-" - "generate a class name from this." + "generate a class name from this.", ): inputs_to_dict(unhashable_spec) @@ -103,21 +98,17 @@ def test_inputs_to_dict(self): self.assertListEqual(unhashable_spec[key][1], n.inputs[key].value) def test_inputs_to_dataframe(self): - l = 3 - n = inputs_to_dataframe(l) + length = 3 + n = inputs_to_dataframe(length) n.recovery = None # Some tests intentionally fail, and we don't want a file - for i in range(l): - n.inputs[f"row_{i}"] = {"x": i, "xsq": i*i} + for i in range(length): + n.inputs[f"row_{i}"] = {"x": i, "xsq": i * i} n() - self.assertIsInstance( - n.outputs.df.value, - DataFrame, - msg="Confirm output type" - ) + self.assertIsInstance(n.outputs.df.value, DataFrame, msg="Confirm output type") self.assertListEqual( - [i*i for i in range(3)], + [i * i for i in range(length)], n.outputs.df.value["xsq"].to_list(), - msg="Spot check values" + msg="Spot check values", ) d1 = {"a": 1, "b": 1} @@ -125,17 +116,17 @@ def test_inputs_to_dataframe(self): with self.assertRaises( KeyError, msg="If the input rows don't have commensurate keys, we expect to get the " - "relevant pandas error" + "relevant pandas error", ): n(row_0=d1, row_1=d1, row_2=d2) - n = inputs_to_dataframe(l) # Freshly instantiate to remove failed status + n = inputs_to_dataframe(length) # Freshly instantiate to remove failed status n.recovery = None # Next test intentionally fails, and we don't want a file d3 = {"a": 1} with self.assertRaises( ValueError, msg="If the input rows don't have commensurate length, we expect to get " - "the relevant pandas error" + "the relevant pandas error", ): n(row_0=d1, row_1=d3, row_2=d1) @@ -152,19 +143,18 @@ def some_generator(): class DC: """Doesn't even have to be an actual dataclass, just dataclass-like""" + necessary: str with_default: int = 42 with_factory: list = field(default_factory=some_generator) n = dataclass_node(DC, label="direct_instance") self.assertIs( - n.dataclass, - DC, - msg="Underlying dataclass should be accessible" + n.dataclass, DC, msg="Underlying dataclass should be accessible" ) self.assertTrue( is_dataclass(n.dataclass), - msg="Underlying dataclass should be a real dataclass" + msg="Underlying dataclass should be a real dataclass", ) self.assertTrue( is_dataclass(DC), @@ -173,49 +163,47 @@ class DC: "too is now a real dataclass, even though it wasn't defined as " "one! This is just a side effect. I don't see it being harmful, " "but in case it gives some future reader trouble, I want to " - "explicitly note the side effect here in the tests." 
+ "explicitly note the side effect here in the tests.", ) self.assertListEqual( list(DC.__dataclass_fields__.keys()), n.inputs.labels, - msg="Inputs should correspond exactly to fields" + msg="Inputs should correspond exactly to fields", ) self.assertIs( DC, n.outputs.dataclass.type_hint, - msg="Output type hint should get automatically set" + msg="Output type hint should get automatically set", ) key = random.choice(n.inputs.labels) self.assertIs( DC.__dataclass_fields__[key].type, n.inputs[key].type_hint, - msg="Spot-check input type hints are pulled from dataclass fields" + msg="Spot-check input type hints are pulled from dataclass fields", ) self.assertFalse( n.inputs.necessary.ready, - msg="Fields with no default and no default factory should not be ready" + msg="Fields with no default and no default factory should not be ready", ) self.assertTrue( - n.inputs.with_default.ready, - msg="Fields with default should be ready" + n.inputs.with_default.ready, msg="Fields with default should be ready" ) self.assertTrue( n.inputs.with_factory.ready, - msg="Fields with default factory should be ready" + msg="Fields with default factory should be ready", ) self.assertListEqual( n.inputs.with_factory.value, some_generator(), - msg="Verify the generator is being used to set the intial value" + msg="Verify the generator is being used to set the intial value", ) out = n(necessary="something") self.assertIsInstance( - out, - DC, - msg="Node should output an instance of the dataclass" + out, DC, msg="Node should output an instance of the dataclass" ) with self.subTest("From decorator"): + @as_dataclass_node @dataclass class DecoratedDC: @@ -231,66 +219,58 @@ class DecoratedDCLike: for n_cls, style in zip( [DecoratedDC(label="dcinst"), DecoratedDCLike(label="dcinst")], - ["Actual dataclass", "Dataclass-like class"], strict=False + ["Actual dataclass", "Dataclass-like class"], + strict=False, ): with self.subTest(style): self.assertTrue( is_dataclass(n_cls.dataclass), - msg="Underlying dataclass should be available on node class" + msg="Underlying dataclass should be available on node class", ) prev = n_cls.preview_inputs() key = random.choice(list(prev.keys())) self.assertIs( n_cls._dataclass_fields[key].type, prev[key][0], - msg="Spot-check input type hints are pulled from dataclass fields" + msg="Spot-check input type hints are pulled from dataclass fields", ) self.assertIs( - prev["necessary"][1], - NOT_DATA, - msg="Field has no default" + prev["necessary"][1], NOT_DATA, msg="Field has no default" ) self.assertEqual( n_cls._dataclass_fields["with_default"].default, prev["with_default"][1], - msg="Fields with default should get scraped" + msg="Fields with default should get scraped", ) self.assertIs( prev["with_factory"][1], NOT_DATA, msg="Fields with default factory won't see their default until " - "instantiation" + "instantiation", ) def test_dataclass_typing_and_storage(self): md = MyData() - with self.assertRaises( - TypeError, - msg="Wrongly typed input should not connect" - ): + with self.assertRaises(TypeError, msg="Wrongly typed input should not connect"): Downstream(5) ds = Downstream(md) out = ds.pull() - self.assertTrue( - out.stuff, - msg="Sanity check" - ) + self.assertTrue(out.stuff, msg="Sanity check") rmd = pickle.loads(pickle.dumps(md)) self.assertIs( rmd.outputs.dataclass.type_hint, MyData.dataclass, - msg="Type hint should be findable on the scope of the node decorating it" + msg="Type hint should be findable on the scope of the node decorating it", ) ds2 = Downstream(rmd) out = 
ds2.pull() self.assertTrue( - out.stuff, - msg="Flow should be able to survive (de)serialization" + out.stuff, msg="Flow should be able to survive (de)serialization" ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_channels.py b/tests/unit/test_channels.py index d3b778a2..fda17d62 100644 --- a/tests/unit/test_channels.py +++ b/tests/unit/test_channels.py @@ -32,6 +32,7 @@ def data_input_locked(self): class InputChannel(Channel): """Just to de-abstract the base class""" + def __str__(self): return "non-abstract input" @@ -42,6 +43,7 @@ def connection_partner_type(self) -> type[Channel]: class OutputChannel(Channel): """Just to de-abstract the base class""" + def __str__(self): return "non-abstract output" @@ -51,22 +53,17 @@ def connection_partner_type(self) -> type[Channel]: class TestChannel(unittest.TestCase): - def setUp(self) -> None: self.inp = InputChannel("inp", DummyOwner()) self.out = OutputChannel("out", DummyOwner()) self.out2 = OutputChannel("out2", DummyOwner()) def test_connection_validity(self): - with self.assertRaises( - TypeError, - msg="Can't connect to non-channels" - ): + with self.assertRaises(TypeError, msg="Can't connect to non-channels"): self.inp.connect("not an owner") with self.assertRaises( - TypeError, - msg="Can't connect to channels that are not the partner type" + TypeError, msg="Can't connect to channels that are not the partner type" ): self.inp.connect(InputChannel("also_input", DummyOwner())) @@ -79,25 +76,21 @@ def test_connection_reflexivity(self): self.assertIs( self.inp.connections[0], self.out, - msg="Connecting a conjugate pair should work fine" + msg="Connecting a conjugate pair should work fine", ) self.assertIs( - self.out.connections[0], - self.inp, - msg="Promised connection to be reflexive" + self.out.connections[0], self.inp, msg="Promised connection to be reflexive" ) self.out.disconnect_all() self.assertListEqual( - [], - self.inp.connections, - msg="Promised disconnection to be reflexive too" + [], self.inp.connections, msg="Promised disconnection to be reflexive too" ) self.out.connect(self.inp) self.assertIs( self.inp.connections[0], self.out, - msg="Connecting should work in either direction" + msg="Connecting should work in either direction", ) def test_connect_and_disconnect(self): @@ -108,7 +101,7 @@ def test_connect_and_disconnect(self): [(self.inp, self.out2), (self.inp, self.out)], disconnected, msg="Broken connection pairs should be returned in the order they were " - "broken" + "broken", ) def test_iterability(self): @@ -118,12 +111,11 @@ def test_iterability(self): self.assertIs( self.inp.connections[i], conn, - msg="Promised channels to be iterable over connections" + msg="Promised channels to be iterable over connections", ) class TestDataChannels(unittest.TestCase): - def setUp(self) -> None: self.ni1 = InputData( label="numeric", owner=DummyOwner(), default=1, type_hint=int | float @@ -151,7 +143,7 @@ def test_mutable_defaults(self): self.assertEqual( len(so2.default), len(self.so1.default) - 1, - msg="Mutable defaults should avoid sharing between different instances" + msg="Mutable defaults should avoid sharing between different instances", ) def test_fetch(self): @@ -162,16 +154,14 @@ def test_fetch(self): self.ni1.connect(self.no_empty) self.assertEqual( - self.ni1.value, - 1, - msg="Data should not be getting pushed on connection" + self.ni1.value, 1, msg="Data should not be getting pushed on connection" ) self.ni1.fetch() self.assertEqual( self.ni1.value, 1, - 
msg="NOT_DATA values should not be getting pulled, so no update expected" + msg="NOT_DATA values should not be getting pulled, so no update expected", ) self.no.value = 3 @@ -180,7 +170,7 @@ def test_fetch(self): self.ni1.value, 3, msg="Data fetch should to first connected value that's actually data," - "in this case skipping over no_empty" + "in this case skipping over no_empty", ) self.no_empty.value = 4 @@ -189,7 +179,7 @@ def test_fetch(self): self.ni1.value, 4, msg="As soon as no_empty actually has data, it's position as 0th " - "element in the connections list should give it priority" + "element in the connections list should give it priority", ) def test_connection_validity(self): @@ -200,18 +190,18 @@ def test_connection_validity(self): self.assertIn( self.no, self.ni1.connections, - msg="Input types should be allowed to be a super-set of output types" + msg="Input types should be allowed to be a super-set of output types", ) with self.assertRaises( ChannelConnectionError, - msg="Input types should not be allowed to be a sub-set of output types" + msg="Input types should not be allowed to be a sub-set of output types", ): self.no.connect(self.ni2) with self.assertRaises( ChannelConnectionError, - msg="Totally different type hints should not allow connections" + msg="Totally different type hints should not allow connections", ): self.so1.connect(self.ni2) @@ -220,7 +210,7 @@ def test_connection_validity(self): self.assertIn( self.so1, self.ni2.connections, - msg="With strict connections turned off, we should allow type-violations" + msg="With strict connections turned off, we should allow type-violations", ) def test_copy_connections(self): @@ -230,7 +220,7 @@ def test_copy_connections(self): self.assertListEqual( self.ni2.connections, [*self.ni1.connections, self.no_empty], - msg="Copying should be additive, existing connections should still be there" + msg="Copying should be additive, existing connections should still be there", ) self.ni2.disconnect(*self.ni1.connections) @@ -238,14 +228,14 @@ def test_copy_connections(self): with self.assertRaises( ChannelConnectionError, msg="Should not be able to connect to so1 because of type hint " - "incompatibility" + "incompatibility", ): self.ni2.copy_connections(self.ni1) self.assertListEqual( self.ni2.connections, [self.no_empty], msg="On failing, copy should revert the copying channel to its orignial " - "state" + "state", ) def test_value_receiver(self): @@ -260,25 +250,23 @@ def test_value_receiver(self): self.assertEqual( new_value, self.ni2.value, - msg="Value-linked owners should automatically get new values" + msg="Value-linked owners should automatically get new values", ) self.ni2.value = 3 self.assertEqual( self.ni1.value, new_value, - msg="Coupling is uni-directional, the partner should not push values back" + msg="Coupling is uni-directional, the partner should not push values back", ) with self.assertRaises( - TypeError, - msg="Only data channels of the same class are valid partners" + TypeError, msg="Only data channels of the same class are valid partners" ): self.ni1.value_receiver = self.no with self.assertRaises( - ValueError, - msg="Must not couple to self to avoid infinite recursion" + ValueError, msg="Must not couple to self to avoid infinite recursion" ): self.ni1.value_receiver = self.ni1 @@ -305,14 +293,13 @@ def test_value_assignment(self): self.ni1.owner.locked = True with self.assertRaises( RuntimeError, - msg="Input data should be locked while its owner has data_input_locked" + msg="Input data should be 
locked while its owner has data_input_locked", ): self.ni1.value = 3 self.ni1.owner.locked = False with self.assertRaises( - TypeError, - msg="Should not be able to take values of the wrong type" + TypeError, msg="Should not be able to take values of the wrong type" ): self.ni2.value = [2] @@ -330,12 +317,12 @@ def test_ready(self): without_default.value, NOT_DATA, msg=f"Without a default, spec is to have a NOT_DATA value but got " - f"{type(without_default.value)}" + f"{type(without_default.value)}", ) self.assertFalse( without_default.ready, msg="Even without type hints, readiness should be false when the value " - "is NOT_DATA" + "is NOT_DATA", ) self.ni1.value = 1 @@ -347,7 +334,7 @@ self.ni1.strict_hints = False self.assertTrue( self.ni1.ready, - msg="Without checking the hint, we should only car that there's data" + msg="Without checking the hint, we should only care that there's data", ) def test_if_not_data(self): @@ -380,7 +367,9 @@ def test_connections(self): self.assertEqual(len(self.out.connections), 0) with self.subTest("No connections to non-SignalChannels"): - bad = InputData(label="numeric", owner=DummyOwner(), default=1, type_hint=int) + bad = InputData( + label="numeric", owner=DummyOwner(), default=1, type_hint=int + ) with self.assertRaises(TypeError): self.inp.connect(bad) @@ -403,7 +392,7 @@ def test_aggregating_call(self): with self.assertRaises( TypeError, msg="For an aggregating input signal, it _matters_ who called it, so " - "receiving an output signal is not optional" + "receiving an output signal is not optional", ): agg() @@ -411,55 +400,43 @@ agg.connect(self.out, out2) self.assertEqual( - 2, - len(agg.connections), - msg="Sanity check on initial conditions" + 2, len(agg.connections), msg="Sanity check on initial conditions" ) self.assertEqual( - 0, - len(agg.received_signals), - msg="Sanity check on initial conditions" - ) - self.assertListEqual( - [0], - owner.foo, - msg="Sanity check on initial conditions" + 0, len(agg.received_signals), msg="Sanity check on initial conditions" ) + self.assertListEqual([0], owner.foo, msg="Sanity check on initial conditions") self.out() - self.assertEqual( - 1, - len(agg.received_signals), - msg="Signal should be received" - ) + self.assertEqual(1, len(agg.received_signals), msg="Signal should be received") self.assertListEqual( [0], owner.foo, - msg="Receiving only _one_ of your connections should not fire the callback" + msg="Receiving only _one_ of your connections should not fire the callback", ) self.out() self.assertEqual( 1, len(agg.received_signals), - msg="Repeatedly receiving the same signal should have no effect" + msg="Repeatedly receiving the same signal should have no effect", ) self.assertListEqual( [0], owner.foo, - msg="Repeatedly receiving the same signal should have no effect" + msg="Repeatedly receiving the same signal should have no effect", ) out2() self.assertListEqual( [0, 1], owner.foo, - msg="After 2/2 output signals have fired, the callback should fire" + msg="After 2/2 output signals have fired, the callback should fire", ) self.assertEqual( 0, len(agg.received_signals), - msg="Firing the callback should reset the list of received signals" + msg="Firing the callback should reset the list of received signals", ) out2() @@ -469,12 +446,12 @@ def test_aggregating_call(self): [0, 1, 2], owner.foo, msg="Having a vestigial received signal (i.e. 
one from an output signal " - "that is no longer connected) shouldn't hurt anything" + "that is no longer connected) shouldn't hurt anything", ) self.assertEqual( 0, len(agg.received_signals), - msg="All signals, including vestigial ones, should get cleared on call" + msg="All signals, including vestigial ones, should get cleared on call", ) def test_callbacks(self): @@ -510,7 +487,7 @@ def doesnt_belong_to_owner(): owner.update, owner.method_with_only_kwargs, owner.staticmethod_without_args, - owner.classmethod_without_args + owner.classmethod_without_args, ]: with self.subTest(callback.__name__): InputSignal(label="inp", owner=owner, callback=callback) @@ -522,9 +499,12 @@ def doesnt_belong_to_owner(): owner.classmethod_with_args, doesnt_belong_to_owner, ]: - with self.subTest(callback.__name__), self.assertRaises(BadCallbackError): + with ( + self.subTest(callback.__name__), + self.assertRaises(BadCallbackError), + ): InputSignal(label="inp", owner=owner, callback=callback) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_create.py b/tests/unit/test_create.py index 157ea5c7..1bec99d2 100644 --- a/tests/unit/test_create.py +++ b/tests/unit/test_create.py @@ -8,5 +8,5 @@ def test_instantiate(self): Creator() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_docs.py b/tests/unit/test_docs.py index 3b730f79..b8d88e3f 100644 --- a/tests/unit/test_docs.py +++ b/tests/unit/test_docs.py @@ -7,7 +7,7 @@ def load_tests(loader, tests, ignore): for _importer, name, _ispkg in pkgutil.walk_packages( - pyiron_workflow.__path__, pyiron_workflow.__name__ + '.' + pyiron_workflow.__path__, pyiron_workflow.__name__ + "." ): tests.addTests(doctest.DocTestSuite(name)) return tests @@ -22,5 +22,5 @@ def test_void(self): pass -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_find.py b/tests/unit/test_find.py index b638e441..812f78ad 100644 --- a/tests/unit/test_find.py +++ b/tests/unit/test_find.py @@ -27,19 +27,20 @@ def test_find_nodes(self): ensure_tests_in_python_path() from static import demo_nodes + found_by_module = find_nodes(demo_nodes) self.assertListEqual( [o.__name__ for o in found_by_path], [o.__name__ for o in found_by_string], msg=f"You should find the same thing regardless of source representation;" - f"by path got {found_by_path} and by string got {found_by_string}" + f"by path got {found_by_path} and by string got {found_by_string}", ) self.assertListEqual( [o.__name__ for o in found_by_string], [o.__name__ for o in found_by_module], msg=f"You should find the same thing regardless of source representation;" - f"by string got {found_by_string} and by module got {found_by_module}" + f"by string got {found_by_string} and by module got {found_by_module}", ) self.assertListEqual( [o.__name__ for o in found_by_string], @@ -47,13 +48,13 @@ def test_find_nodes(self): demo_nodes.AddPlusOne.__name__, demo_nodes.AddThree.__name__, demo_nodes.Dynamic.__name__, - demo_nodes.OptionallyAdd.__name__ + demo_nodes.OptionallyAdd.__name__, ], msg=f"Should match a hand-selected expectation list that ignores the " f"private and non-local nodes. If you update the demo nodes this may " - f"fail and need to be trivially updated. Got {found_by_module}" + f"fail and need to be trivially updated. 
Got {found_by_module}", ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_io.py b/tests/unit/test_io.py index 639a40de..614f1fef 100644 --- a/tests/unit/test_io.py +++ b/tests/unit/test_io.py @@ -18,7 +18,6 @@ class Dummy(HasIO): - def __init__(self, label: str | None = "has_io"): super().__init__() self._label = label @@ -49,12 +48,11 @@ def data_input_locked(self): class TestDataIO(unittest.TestCase): - def setUp(self) -> None: has_io = Dummy() self.inputs = [ - InputData(label="x", owner=has_io, default=0., type_hint=float), - InputData(label="y", owner=has_io, default=1., type_hint=float) + InputData(label="x", owner=has_io, default=0.0, type_hint=float), + InputData(label="y", owner=has_io, default=1.0, type_hint=float), ] outputs = [ OutputData(label="a", owner=has_io, type_hint=float), @@ -85,46 +83,42 @@ def test_assignment(self): self.assertIs( self.output.not_this_channels_name, self.post_facto_output, - msg="Expected channel to get assigned" + msg="Expected channel to get assigned", ) self.assertEqual( self.post_facto_output.label, label_before_assignment, msg="Labels should not get updated on assignment of channels to IO " - "collections" + "collections", ) def test_connection(self): with self.assertRaises( - TypeError, - msg="Shouldn't be allowed to connect two inputs" + TypeError, msg="Shouldn't be allowed to connect two inputs" ): self.input.x = self.input.y self.assertEqual( 0, len(self.input.x.connections), - msg="Sanity check that the above error-raising connection never got made" + msg="Sanity check that the above error-raising connection never got made", ) self.input.x = self.output.a self.assertIn( self.input.x, self.output.a.connections, - msg="Should be able to create connections by assignment" + msg="Should be able to create connections by assignment", ) - self.input.x = 7. - self.assertEqual(self.input.x.value, 7.) 
+ self.input.x = 7.0 + self.assertEqual(self.input.x.value, 7.0) self.input.y = self.output.a disconnected = self.input.disconnect() self.assertListEqual( disconnected, - [ - (self.input.x, self.output.a), - (self.input.y, self.output.a) - ], - msg="Disconnecting the panel should disconnect all children" + [(self.input.x, self.output.a), (self.input.y, self.output.a)], + msg="Disconnecting the panel should disconnect all children", ) def test_conversion(self): @@ -134,7 +128,7 @@ def test_conversion(self): self.assertEqual( len(self.inputs), len(converted), - msg="And it shouldn't have any extra items either" + msg="And it shouldn't have any extra items either", ) def test_iteration(self): @@ -144,12 +138,12 @@ def test_connections_property(self): self.assertEqual( len(self.input.connections), 0, - msg="Sanity check expectations about self.input" + msg="Sanity check expectations about self.input", ) self.assertEqual( len(self.output.connections), 0, - msg="Sanity check expectations about self.input" + msg="Sanity check expectations about self.output", ) for inp in self.input: @@ -158,27 +152,27 @@ def test_connections_property(self): self.assertEqual( len(self.output.connections), len(self.input), - msg="Expected to find all the channels in the input" + msg="Expected to find all the channels in the input", ) self.assertEqual( len(self.input.connections), 1, - msg="Each unique connection should appear only once" + msg="Each unique connection should appear only once", ) self.assertIs( self.input.connections[0], self.input.x.connections[0], msg="The IO connection found should be the same object as the channel " - "connection" + "connection", ) def test_to_list(self): self.assertListEqual( - [0., 1.], + [0.0, 1.0], self.input.to_list(), msg="Expected a shortcut to channel values. 
Order is explicitly not " "guaranteed in the docstring, but it would be nice to appear in the " - "order the channels are added here" + "order the channels are added here", ) @@ -208,21 +202,21 @@ def test_disconnect(self): self.assertEqual( 4, len(self.signals.disconnect()), - msg="Disconnect should disconnect all on panels and the Signals super-panel" + msg="Disconnect should disconnect all on panels and the Signals super-panel", ) def test_disconnect_run(self): self.assertEqual( 2, len(self.signals.disconnect_run()), - msg="Should disconnect exactly everything connected to run" + msg="Should disconnect exactly everything connected to run", ) no_run_signals = Signals() self.assertEqual( 0, len(no_run_signals.disconnect_run()), - msg="If there is no run channel, the list of disconnections should be empty" + msg="If there is no run channel, the list of disconnections should be empty", ) @@ -242,37 +236,30 @@ def test_set_input_values(self): self.assertDictEqual( {"input_channel": "v1", "more_input": "v2"}, has_io.inputs.to_value_dict(), - msg="Args should be set by order of channel appearance" + msg="Args should be set by order of channel appearance", ) has_io.set_input_values(more_input="v4", input_channel="v3") self.assertDictEqual( {"input_channel": "v3", "more_input": "v4"}, has_io.inputs.to_value_dict(), - msg="Kwargs should be set by key-label matching" + msg="Kwargs should be set by key-label matching", ) has_io.set_input_values("v5", more_input="v6") self.assertDictEqual( {"input_channel": "v5", "more_input": "v6"}, has_io.inputs.to_value_dict(), - msg="Mixing and matching args and kwargs is permissible" + msg="Mixing and matching args and kwargs is permissible", ) - with self.assertRaises( - ValueError, - msg="More args than channels is disallowed" - ): + with self.assertRaises(ValueError, msg="More args than channels is disallowed"): has_io.set_input_values(1, 2, 3) with self.assertRaises( - ValueError, - msg="A channel updating from both args and kwargs is disallowed" + ValueError, msg="A channel updating from both args and kwargs is disallowed" ): has_io.set_input_values(1, input_channel=2) - with self.assertRaises( - ValueError, - msg="Kwargs not among input is disallowed" - ): + with self.assertRaises(ValueError, msg="Kwargs not among input is disallowed"): has_io.set_input_values(not_a_channel=42) def test_connected_and_disconnect(self): @@ -281,30 +268,26 @@ def test_connected_and_disconnect(self): has_io1 >> has_io2 self.assertTrue( has_io1.connected, - msg="Any connection should result in a positive connected status" + msg="Any connection should result in a positive connected status", ) has_io1.disconnect() self.assertFalse( - has_io1.connected, - msg="Disconnect should break all connections" + has_io1.connected, msg="Disconnect should break all connections" ) def test_strict_hints(self): has_io = Dummy() has_io.inputs["input_channel"] = InputData("input_channel", has_io) - self.assertTrue( - has_io.inputs.input_channel.strict_hints, - msg="Sanity check" - ) + self.assertTrue(has_io.inputs.input_channel.strict_hints, msg="Sanity check") has_io.deactivate_strict_hints() self.assertFalse( has_io.inputs.input_channel.strict_hints, - msg="Hint strictness should be accessible from the top level" + msg="Hint strictness should be accessible from the top level", ) has_io.activate_strict_hints() self.assertTrue( has_io.inputs.input_channel.strict_hints, - msg="Hint strictness should be accessible from the top level" + msg="Hint strictness should be accessible from the top level", 
) def test_rshift_operator(self): @@ -314,7 +297,7 @@ def test_rshift_operator(self): self.assertIn( has_io1.signals.output.ran, has_io2.signals.input.run.connections, - msg="Right shift should be syntactic sugar for an 'or' run connection" + msg="Right shift should be syntactic sugar for an 'or' run connection", ) def test_lshift_operator(self): @@ -324,7 +307,7 @@ def test_lshift_operator(self): self.assertIn( has_io1.signals.input.accumulate_and_run, has_io2.signals.output.ran.connections, - msg="Left shift should be syntactic sugar for an 'and' run connection" + msg="Left shift should be syntactic sugar for an 'and' run connection", ) has_io1.disconnect() @@ -334,7 +317,7 @@ def test_lshift_operator(self): self.assertListEqual( [has_io3.signals.output.ran, has_io2.signals.output.ran], has_io1.signals.input.accumulate_and_run.connections, - msg="Left shift should accommodate groups of connections" + msg="Left shift should accommodate groups of connections", ) def test_copy_io(self): @@ -355,10 +338,14 @@ def test_copy_io(self): to_copy.outputs["used_output"] = OutputData("used_output", to_copy) to_copy.outputs["unused_output"] = OutputData("unused_output", to_copy) to_copy.signals.input["custom_signal"] = InputSignal( - "custom_signal", to_copy, to_copy.update, + "custom_signal", + to_copy, + to_copy.update, ) to_copy.signals.input["unused_signal"] = InputSignal( - "unused_signal", to_copy, to_copy.update, + "unused_signal", + to_copy, + to_copy.update, ) downstream = Dummy(label="downstream") @@ -376,7 +363,7 @@ def test_copy_io(self): with self.assertRaises( ConnectionCopyError, msg="The copier is missing all sorts of connected channels and should " - "fail to copy" + "fail to copy", ): copier.copy_io( to_copy, connections_fail_hard=True, values_fail_hard=False @@ -384,51 +371,52 @@ def test_copy_io(self): self.assertFalse( copier.connected, msg="After a failure, any connections that _were_ made should get " - "reset" + "reset", ) with self.subTest("Force missing connections"): - copier.copy_io( - to_copy, connections_fail_hard=False, values_fail_hard=False - ) + copier.copy_io(to_copy, connections_fail_hard=False, values_fail_hard=False) self.assertIn( copier.signals.output.ran, downstream.signals.input.run, - msg="The channel that _can_ get copied _should_ get copied" + msg="The channel that _can_ get copied _should_ get copied", ) copier.signals.output.ran.disconnect_all() self.assertFalse( copier.connected, - msg="Sanity check that that was indeed the only connection" + msg="Sanity check that that was indeed the only connection", ) copier.inputs["used_input"] = InputData("used_input", copier) copier.inputs["hinted_input"] = InputData( - "hinted_input", copier, type_hint=str # Different hint! + "hinted_input", + copier, + type_hint=str, # Different hint! 
) copier.inputs["extra_input"] = InputData( "extra_input", copier, default="not on the copied object but that's ok" ) copier.outputs["used_output"] = OutputData("used_output", copier) copier.signals.input["custom_signal"] = InputSignal( - "custom_signal", copier, copier.update, + "custom_signal", + copier, + copier.update, ) - with self.subTest("Bad hint causes connection error"),self.assertRaises( + with ( + self.subTest("Bad hint causes connection error"), + self.assertRaises( ConnectionCopyError, msg="Can't connect channels with incommensurate type hints", + ), ): - copier.copy_io( - to_copy, connections_fail_hard=True, values_fail_hard=False - ) + copier.copy_io(to_copy, connections_fail_hard=True, values_fail_hard=False) # Bring the copier's type hint in-line with the object being copied copier.inputs.hinted_input.type_hint = float with self.subTest("Passes missing values"): - copier.copy_io( - to_copy, connections_fail_hard=True, values_fail_hard=False - ) + copier.copy_io(to_copy, connections_fail_hard=True, values_fail_hard=False) for copier_panel, copied_panel in zip( copier._owned_io_panels, to_copy._owned_io_panels, strict=False ): @@ -438,28 +426,29 @@ def test_copy_io(self): self.assertListEqual( copier_channel.connections, copied_channel.connections, - msg="All connections on shared channels should copy" + msg="All connections on shared channels should copy", ) if isinstance(copier_channel, DataChannel): self.assertEqual( copier_channel.value, copied_channel.value, - msg="All values on shared channels should copy" + msg="All values on shared channels should copy", ) except AttributeError: # We only need to check shared channels pass - with self.subTest("Force failure on value copy fail"),self.assertRaises( + with ( + self.subTest("Force failure on value copy fail"), + self.assertRaises( ValueCopyError, msg="The copier doesn't have channels to hold all the values that need" - "copying, so we should fail" + "copying, so we should fail", + ), ): - copier.copy_io( - to_copy, connections_fail_hard=True, values_fail_hard=True - ) + copier.copy_io(to_copy, connections_fail_hard=True, values_fail_hard=True) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_node.py b/tests/unit/test_node.py index dfc2532c..df004cbc 100644 --- a/tests/unit/test_node.py +++ b/tests/unit/test_node.py @@ -20,7 +20,9 @@ class ANode(Node): """To de-abstract the class""" def _setup_node(self) -> None: - self._inputs = Inputs(InputData("x", self, type_hint=int),) + self._inputs = Inputs( + InputData("x", self, type_hint=int), + ) self._outputs = OutputsWithInjection( OutputDataWithInjection("y", self, type_hint=int), ) @@ -60,7 +62,7 @@ def test_set_input_values(self): self.assertEqual( 2, n.inputs.x.value, - msg="Post-instantiation update of inputs should also work" + msg="Post-instantiation update of inputs should also work", ) with self.assertRaises(ValueError, msg="Non-input-channel kwargs not allowed"): @@ -77,28 +79,27 @@ def test_set_input_values(self): self.assertEqual( "not an int", n.inputs.x.value, - msg="It should be possible to deactivate type checking from the node level" + msg="It should be possible to deactivate type checking from the node level", ) def test_run_data_tree(self): self.assertEqual( add_one(add_one(add_one(self.n1.inputs.x.value))), self.n3.run(run_data_tree=True), - msg="Should pull start down to end, even with no flow defined" + msg="Should pull start down to end, even with no flow defined", ) def test_fetch_input(self): 
self.n1.outputs.y.value = 0 with self.assertRaises( - ValueError, - msg="Without input, we should not achieve readiness" + ValueError, msg="Without input, we should not achieve readiness" ): self.n2.run(run_data_tree=False, fetch_input=False, check_readiness=True) self.assertEqual( add_one(self.n1.outputs.y.value), self.n2.run(run_data_tree=False, fetch_input=True), - msg="After fetching the upstream data, should run fine" + msg="After fetching the upstream data, should run fine", ) def test_check_readiness(self): @@ -108,32 +109,31 @@ def test_check_readiness(self): # but don't care about generating a file with self.assertRaises( - ValueError, - msg="When input is not data, we should fail early" + ValueError, msg="When input is not data, we should fail early" ): self.n3.run(run_data_tree=False, fetch_input=False, check_readiness=True) self.assertFalse( self.n3.failed, msg="The benefit of the readiness check should be that we don't actually " - "qualify as failed" + "qualify as failed", ) with self.assertRaises( TypeError, - msg="If we bypass the check, we should get the failing function error" + msg="If we bypass the check, we should get the failing function error", ): self.n3.run(run_data_tree=False, fetch_input=False, check_readiness=False) self.assertTrue( self.n3.failed, - msg="If the node operation itself fails, the status should be failed" + msg="If the node operation itself fails, the status should be failed", ) self.n3.inputs.x = 0 with self.assertRaises( ValueError, - msg="When status is failed, we should fail early, even if input data is ok" + msg="When status is failed, we should fail early, even if input data is ok", ): self.n3.run(run_data_tree=False, fetch_input=False, check_readiness=True) @@ -142,7 +142,7 @@ def test_check_readiness(self): 1, self.n3.run(run_data_tree=False, fetch_input=False, check_readiness=True), msg="After manually resetting the failed state and providing good input, " - "running should proceed" + "running should proceed", ) self.n3.use_cache = n3_cache @@ -153,7 +153,7 @@ def test_emit_ran_signal(self): self.n1.run(emit_ran_signal=False) self.assertFalse( self.n3.inputs.x.ready, - msg="Without emitting the ran signal, nothing should happen downstream" + msg="Without emitting the ran signal, nothing should happen downstream", ) self.n1.run(emit_ran_signal=True) @@ -161,7 +161,7 @@ def test_emit_ran_signal(self): add_one(add_one(add_one(self.n1.inputs.x.value))), self.n3.outputs.y.value, msg="With the connection and signal, we should have pushed downstream " - "execution" + "execution", ) def test_failure_signal(self): @@ -182,11 +182,7 @@ def add(self, signal): except TypeError: # Expected -- we're _trying_ to get failure to fire n.delete_storage(filename=n.as_path().joinpath("recovery")) - self.assertEqual( - c.count, - 1, - msg="Failed signal should fire after type error" - ) + self.assertEqual(c.count, 1, msg="Failed signal should fire after type error") def test_failure_recovery(self): n = ANode(label="failing") @@ -198,7 +194,7 @@ def test_failure_recovery(self): self.assertFalse( n.as_path().exists(), msg="When the run exception is not raised, we don't expect any " - "recovery file to be needed" + "recovery file to be needed", ) default_recovery = n.recovery @@ -209,7 +205,7 @@ def test_failure_recovery(self): self.assertFalse( n.has_saved_content(filename=n.as_path().joinpath("recovery")), msg="Without a recovery back end specified, we don't expect a file to " - "be saved on failure." 
+ "be saved on failure.", ) n.recovery = default_recovery @@ -218,25 +214,24 @@ def test_failure_recovery(self): n.run(check_readiness=False) self.assertTrue( n.has_saved_content(filename=n.as_path().joinpath("recovery")), - msg="Expect a recovery file to be saved on failure" + msg="Expect a recovery file to be saved on failure", ) reloaded = ANode(label="failing", autoload=True) self.assertIs( reloaded.inputs.x.value, NOT_DATA, - msg="We don't anticipate _auto_ loading from recovery files" + msg="We don't anticipate _auto_ loading from recovery files", ) self.assertFalse(reloaded.failed, msg="Sanity check") reloaded.load(filename=reloaded.as_path().joinpath("recovery")) self.assertTrue( - reloaded.failed, - msg="Expect to have reloaded the failed node." + reloaded.failed, msg="Expect to have reloaded the failed node." ) self.assertEqual( reloaded.inputs.x.value, n.inputs.x.value, - msg="Expect data to have been reloaded from the failed node" + msg="Expect data to have been reloaded from the failed node", ) finally: @@ -244,7 +239,7 @@ def test_failure_recovery(self): self.assertFalse( n.as_path().exists(), msg="The recovery file should have been the only thing in the node " - "directory, so cleaning should remove the directory entirely." + "directory, so cleaning should remove the directory entirely.", ) def test_execute(self): @@ -253,12 +248,12 @@ def test_execute(self): self.assertEqual( self.n2.run(fetch_input=False, emit_ran_signal=False, x=10) + 1, self.n2.execute(x=11), - msg="Execute should _not_ fetch in the upstream data" + msg="Execute should _not_ fetch in the upstream data", ) self.assertFalse( self.n3.ready, msg="Executing should not be triggering downstream runs, even though we " - "made a ran/run connection" + "made a ran/run connection", ) self.n2.inputs.x._value = "manually override the desired int" @@ -266,7 +261,7 @@ def test_execute(self): with self.assertRaises( TypeError, msg="Execute should be running without a readiness check and hitting the " - "string + int error" + "string + int error", ): self.n2.execute() @@ -274,20 +269,16 @@ def test_pull(self): self.n2 >> self.n3 self.n1.inputs.x = 0 by_run = self.n2.run( - run_data_tree=True, - fetch_input=True, - emit_ran_signal=False - ) + run_data_tree=True, fetch_input=True, emit_ran_signal=False + ) self.n1.inputs.x = 1 self.assertEqual( - by_run + 1, - self.n2.pull(), - msg="Pull should be running the upstream node" + by_run + 1, self.n2.pull(), msg="Pull should be running the upstream node" ) self.assertFalse( self.n3.ready, msg="Pulling should not be triggering downstream runs, even though we " - "made a ran/run connection" + "made a ran/run connection", ) def test___call__(self): @@ -296,20 +287,16 @@ def test___call__(self): self.n2 >> self.n3 self.n1.inputs.x = 0 by_run = self.n2.run( - run_data_tree=True, - fetch_input=True, - emit_ran_signal=False + run_data_tree=True, fetch_input=True, emit_ran_signal=False ) self.n1.inputs.x = 1 self.assertEqual( - by_run + 1, - self.n2(), - msg="A call should be running the upstream node" + by_run + 1, self.n2(), msg="A call should be running the upstream node" ) self.assertFalse( self.n3.ready, msg="Calling should not be triggering downstream runs, even though we " - "made a ran/run connection" + "made a ran/run connection", ) def test_draw(self): @@ -324,11 +311,9 @@ def test_draw(self): # That name is just an implementation detail, update it as # needed self.assertTrue( - self.n1.as_path().joinpath( - expected_name - ).is_file(), + 
self.n1.as_path().joinpath(expected_name).is_file(), msg="If `save` is called, expect the rendered image to " - "exist in the working directory" + "exist in the working directory", ) user_specified_name = "foo" @@ -337,7 +322,7 @@ def test_draw(self): self.assertTrue( self.n1.as_path().joinpath(expected_name).is_file(), msg="If the user specifies a filename, we should assume they want the " - "thing saved" + "thing saved", ) finally: # No matter what happens in the tests, clean up after yourself @@ -350,12 +335,12 @@ def test_autorun(self): self.assertIs( self.n1.outputs.y.value, NOT_DATA, - msg="By default, nodes should not be getting run until asked" + msg="By default, nodes should not be getting run until asked", ) self.assertEqual( 1, ANode(label="right_away", autorun=True, x=0).outputs.y.value, - msg="With autorun, the node should run right away" + msg="With autorun, the node should run right away", ) def test_graph_info(self): @@ -365,14 +350,14 @@ def test_graph_info(self): n.semantic_delimiter + n.label, n.graph_path, msg="Lone nodes should just have their label as the path, as there is no " - "parent above." + "parent above.", ) self.assertIs( n, n.graph_root, msg="Lone nodes should be their own graph_root, as there is no parent " - "above." + "above.", ) def test_single_value(self): @@ -381,7 +366,7 @@ def test_single_value(self): node.outputs.y, node.channel, msg="With a single output, the `HasChannel` interface fulfillment should " - "use that output." + "use that output.", ) with_addition = node + 5 @@ -389,7 +374,7 @@ def test_single_value(self): with_addition, Node, msg="With a single output, acting on the node should fall back on acting " - "on the single (with-injection) output" + "on the single (with-injection) output", ) node2 = ANode(label="n2") @@ -398,22 +383,20 @@ def test_single_value(self): [node.outputs.y], node2.inputs.x.connections, msg="With a single output, the node should fall back on the single output " - "for output-like use cases" + "for output-like use cases", ) node.outputs["z"] = OutputDataWithInjection("z", node, type_hint=int) with self.assertRaises( AmbiguousOutputError, msg="With multiple outputs, trying to exploit the `HasChannel` interface " - "should fail cleanly" + "should fail cleanly", ): node.channel # noqa: B018 def test_storage(self): self.assertIs( - self.n1.outputs.y.value, - NOT_DATA, - msg="Sanity check on initial state" + self.n1.outputs.y.value, NOT_DATA, msg="Sanity check on initial state" ) y = self.n1() @@ -421,7 +404,7 @@ def test_storage(self): with self.assertRaises( FileNotFoundError, - msg="We just verified there is no save file, so loading should fail." 
+ msg="We just verified there is no save file, so loading should fail.", ): self.n1.load() @@ -436,14 +419,16 @@ def test_storage(self): self.assertEqual( y, reloaded.outputs.y.value, - msg="Nodes should load by default if they find a save file" + msg="Nodes should load by default if they find a save file", ) - clean_slate = ANode(label=self.n1.label, x=x, delete_existing_savefiles=True) + clean_slate = ANode( + label=self.n1.label, x=x, delete_existing_savefiles=True + ) self.assertIs( clean_slate.outputs.y.value, NOT_DATA, - msg="Users should be able to ignore a save" + msg="Users should be able to ignore a save", ) run_right_away = ANode( @@ -454,44 +439,39 @@ def test_storage(self): self.assertEqual( y, run_right_away.outputs.y.value, - msg="With nothing to load, running after init is fine" + msg="With nothing to load, running after init is fine", ) run_right_away.save() load_and_rerun_origal_input = ANode( - label=self.n1.label, - autorun=True, - autoload=backend + label=self.n1.label, autorun=True, autoload=backend ) self.assertEqual( load_and_rerun_origal_input.outputs.y.value, run_right_away.outputs.y.value, msg="Loading and then running immediately is fine, and should " - "recover existing input" + "recover existing input", ) load_and_rerun_new_input = ANode( - label=self.n1.label, - x=x + 1, - autorun=True, - autoload=backend + label=self.n1.label, x=x + 1, autorun=True, autoload=backend ) self.assertEqual( load_and_rerun_new_input.outputs.y.value, run_right_away.outputs.y.value + 1, msg="Loading and then running immediately is fine, and should " - "notice the new input" + "notice the new input", ) force_run = ANode( label=self.n1.label, x=x, autorun=True, - delete_existing_savefiles=True + delete_existing_savefiles=True, ) self.assertEqual( y, force_run.outputs.y.value, - msg="Destroying the save should allow immediate re-running" + msg="Destroying the save should allow immediate re-running", ) hard_input = ANode(label="hard") @@ -499,19 +479,16 @@ def test_storage(self): hard_input.inputs.x = lambda x: x * 2 if isinstance(backend, PickleStorage): hard_input.save() - reloaded = ANode( - label=hard_input.label, - autoload=backend - ) + reloaded = ANode(label=hard_input.label, autoload=backend) self.assertEqual( reloaded.inputs.x.value(4), hard_input.inputs.x.value(4), - msg="Cloud pickle should be strong enough to recover this" + msg="Cloud pickle should be strong enough to recover this", ) else: with self.assertRaises( (TypeError, AttributeError), - msg="Other backends are not powerful enough for some values" + msg="Other backends are not powerful enough for some values", ): hard_input.save() finally: @@ -528,11 +505,11 @@ def test_storage_to_filename(self): self.n1.save(backend=backend, filename=fname) self.assertFalse( self.n1.has_saved_content(backend=backend), - msg="There should be no content at the default location" + msg="There should be no content at the default location", ) self.assertTrue( self.n1.has_saved_content(backend=backend, filename=fname), - msg="There should be content at the specified file location" + msg="There should be content at the specified file location", ) new = ANode() new.load(filename=fname) @@ -542,7 +519,7 @@ def test_storage_to_filename(self): self.n1.delete_storage(backend=backend, filename=fname) self.assertFalse( self.n1.has_saved_content(backend=backend, filename=fname), - msg="Deleting storage should have cleaned up the file" + msg="Deleting storage should have cleaned up the file", ) def test_checkpoint(self): @@ -567,7 +544,7 @@ def 
test_checkpoint(self): NOT_DATA, not_reloaded.outputs.y.value, msg="Should not have saved, therefore should have been nothing " - "to load" + "to load", ) find_saved = ANode(label="run_and_save", autoload=backend) @@ -575,7 +552,7 @@ def test_checkpoint(self): y, find_saved.outputs.y.value, msg="Should have saved automatically after run, and reloaded " - "on instantiation" + "on instantiation", ) finally: saves.delete_storage(backend) # Clean up @@ -594,7 +571,7 @@ def test_result_serialization(self): out = n() self.assertTrue( n._temporary_result_file.is_file(), - msg="Sanity check that we've saved the output" + msg="Sanity check that we've saved the output", ) # Now fake it n.running = True @@ -607,9 +584,9 @@ def test_result_serialization(self): self.assertFalse( n.as_path().is_dir(), msg="Actually, we expect cleanup to have removed empty directories up to " - "and including the node's own directory" + "and including the node's own directory", ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_output_parser.py b/tests/unit/test_output_parser.py index f1a96bd5..aecfcd2a 100644 --- a/tests/unit/test_output_parser.py +++ b/tests/unit/test_output_parser.py @@ -7,68 +7,82 @@ class TestParseOutput(unittest.TestCase): def test_parsing(self): with self.subTest("Single return"): + def identity(x): return x + self.assertListEqual(ParseOutput(identity).output, ["x"]) with self.subTest("Expression return"): + def add(x, y): return x + y + self.assertListEqual(ParseOutput(add).output, ["x + y"]) with self.subTest("Weird whitespace"): + def add_with_whitespace(x, y): - return x + y + return x + y + self.assertListEqual(ParseOutput(add_with_whitespace).output, ["x + y"]) with self.subTest("Multiple expressions"): + def add_and_subtract(x, y): return x + y, x - y + self.assertListEqual( - ParseOutput(add_and_subtract).output, - ["x + y", "x - y"] + ParseOutput(add_and_subtract).output, ["x + y", "x - y"] ) with self.subTest("Best-practice (well-named return vars)"): + def md(job): temperature = job.output.temperature energy = job.output.energy return temperature, energy + self.assertListEqual(ParseOutput(md).output, ["temperature", "energy"]) with self.subTest("Function call returns"): + def function_return(i, j): - return ( - math.log( - 10, base=2 - ), - math.atan2(1, 2) - ) + return (math.log(10, base=2), math.atan2(1, 2)) + self.assertListEqual( ParseOutput(function_return).output, - ["math.log( 10, base=2 )", "math.atan2(1, 2)"] + ["math.log(10, base=2)", "math.atan2(1, 2)"], ) with self.subTest("Methods too"): + class Foo: def add(self, x, y): return x + y + self.assertListEqual(ParseOutput(Foo.add).output, ["x + y"]) def test_void(self): with self.subTest("No return"): + def no_return(): pass + self.assertIsNone(ParseOutput(no_return).output) with self.subTest("Empty return"): + def empty_return(): return + self.assertIsNone(ParseOutput(empty_return).output) with self.subTest("Return None explicitly"): + def none_return(): return None + self.assertIsNone(ParseOutput(none_return).output) def test_multiple_branches(self): @@ -77,9 +91,10 @@ def bifurcating(x): return True else: return False + with self.assertRaises(ValueError): ParseOutput(bifurcating) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_pyiron_workflow.py b/tests/unit/test_pyiron_workflow.py index f3322272..8c076e6d 100644 --- a/tests/unit/test_pyiron_workflow.py +++ b/tests/unit/test_pyiron_workflow.py @@ -8,5 +8,5 @@ def 
test_single_point_of_entry(self): # level -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_storage.py b/tests/unit/test_storage.py index 4e16413a..f6c26d32 100644 --- a/tests/unit/test_storage.py +++ b/tests/unit/test_storage.py @@ -8,7 +8,6 @@ class TestAvailableBackends(unittest.TestCase): - def test_default_backend(self): backends = list(available_backends()) self.assertIsInstance( @@ -16,7 +15,7 @@ def test_default_backend(self): PickleStorage, msg="If more standard backends are added, this will fail -- that's fine, " "just update the test to make sure you're getting the defaults you now " - "expect." + "expect.", ) def test_specific_backend(self): @@ -26,7 +25,7 @@ def test_specific_backend(self): 1, msg="Once more standard backends are available, we should test that string " "access results in the the correct priority assignment among these " - "defaults." + "defaults.", ) self.assertIsInstance(backends[0], PickleStorage) @@ -34,15 +33,11 @@ def test_extra_backend(self): my_interface = PickleStorage() backends = list(available_backends(my_interface)) self.assertEqual( - len(backends), - 2, - msg="We expect both the one we passed, and all defaults" + len(backends), 2, msg="We expect both the one we passed, and all defaults" ) self.assertIs(backends[0], my_interface) self.assertIsNot( - backends[0], - backends[1], - msg="They should be separate instances" + backends[0], backends[1], msg="They should be separate instances" ) def test_exclusive_backend(self): @@ -51,13 +46,12 @@ def test_exclusive_backend(self): self.assertEqual( len(backends), 1, - msg="We expect to filter out everything except the one we asked for" + msg="We expect to filter out everything except the one we asked for", ) self.assertIs(backends[0], my_interface) class TestStorage(unittest.TestCase): - def setUp(self): self.node = UserInput(label="test_node") self.storage = PickleStorage() @@ -94,7 +88,7 @@ def test_input_validity(self): for method in [ self.storage.load, self.storage.has_saved_content, - self.storage.delete + self.storage.delete, ]: with self.subTest(method.__name__): with self.assertRaises(ValueError): @@ -118,7 +112,7 @@ def Unimportable(x): interface = PickleStorage(cloudpickle_fallback=False) with self.assertRaises( TypeNotFoundError, - msg="We can't import from , so this is unpicklable" + msg="We can't import from , so this is unpicklable", ): interface.save(u) @@ -134,4 +128,3 @@ def Unimportable(x): if __name__ == "__main__": unittest.main() - diff --git a/tests/unit/test_type_hinting.py b/tests/unit/test_type_hinting.py index fdaa2a91..76f84bc6 100644 --- a/tests/unit/test_type_hinting.py +++ b/tests/unit/test_type_hinting.py @@ -21,19 +21,19 @@ def __call__(self): ureg = UnitRegistry() for hint, good, bad in ( - (int | float, 1, "foo"), - (int | float, 2.0, "bar"), - (typing.Literal[1, 2], 2, 3), - (typing.Literal[1, 2], 1, "baz"), - (Foo, Foo(), Foo), - (type[Bar], Bar, Bar()), - # (callable, Bar(), Foo()), # Misses the bad! 
- # Can't hint args and returns without typing.Callable anyhow, so that's - # what people should be using regardless - (typing.Callable, Bar(), Foo()), - (tuple[int, float], (1, 1.1), ("fo", 0)), - (dict[str, int], {'a': 1}, {'a': 'b'}), - (int, 1 * ureg.seconds, 1.0 * ureg.seconds) # Disregard unit, look@type + (int | float, 1, "foo"), + (int | float, 2.0, "bar"), + (typing.Literal[1, 2], 2, 3), + (typing.Literal[1, 2], 1, "baz"), + (Foo, Foo(), Foo), + (type[Bar], Bar, Bar()), + # (callable, Bar(), Foo()), # Misses the bad! + # Can't hint args and returns without typing.Callable anyhow, so that's + # what people should be using regardless + (typing.Callable, Bar(), Foo()), + (tuple[int, float], (1, 1.1), ("fo", 0)), + (dict[str, int], {"a": 1}, {"a": "b"}), + (int, 1 * ureg.seconds, 1.0 * ureg.seconds), # Disregard unit, look@type ): with self.subTest(msg=f"Good {good} vs hint {hint}"): self.assertTrue(valid_value(good, hint)) @@ -63,29 +63,29 @@ def test_hint_comparisons(self): (dict[int, str], dict[str, int], False), (typing.Callable[[int, float], None], typing.Callable, True), ( - typing.Callable[[int, float], None], - typing.Callable[[float, int], None], - False + typing.Callable[[int, float], None], + typing.Callable[[float, int], None], + False, ), ( - typing.Callable[[int, float], float], - typing.Callable[[int, float], float | str], - True + typing.Callable[[int, float], float], + typing.Callable[[int, float], float | str], + True, ), ( - typing.Callable[[int, float, str], float], - typing.Callable[[int, float], float], - False + typing.Callable[[int, float, str], float], + typing.Callable[[int, float], float], + False, ), ]: with self.subTest( - target=target, reference=reference, expected=is_more_specific + target=target, reference=reference, expected=is_more_specific ): self.assertEqual( type_hint_is_as_or_more_specific_than(target, reference), - is_more_specific + is_more_specific, ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_workflow.py b/tests/unit/test_workflow.py index 6a7a952c..e74dcd26 100644 --- a/tests/unit/test_workflow.py +++ b/tests/unit/test_workflow.py @@ -27,7 +27,7 @@ def PlusOne(x: int = 0): @Workflow.wrap.as_function_node -def five(sleep_time=0.): +def five(sleep_time=0.0): sleep(sleep_time) five = 5 return five @@ -39,7 +39,6 @@ def sum(a, b): class TestWorkflow(unittest.TestCase): - def test_io(self): wf = Workflow("wf") wf.n1 = wf.create.function_node(plus_one) @@ -48,9 +47,7 @@ def test_io(self): inp = wf.inputs inp_again = wf.inputs - self.assertIsNot( - inp, inp_again, msg="Workflow input should always get rebuilt" - ) + self.assertIsNot(inp, inp_again, msg="Workflow input should always get rebuilt") n_in = len(wf.inputs) n_out = len(wf.outputs) @@ -67,7 +64,7 @@ def test_io(self): wf.n3.inputs.x = wf.n2.outputs.y wf.n2.inputs.x = wf.n1.outputs.y self.assertEqual( - n_in -2, len(wf.inputs), msg="New connections should get reflected" + n_in - 2, len(wf.inputs), msg="New connections should get reflected" ) self.assertEqual( n_out - 2, len(wf.outputs), msg="New connections should get reflected" @@ -84,7 +81,7 @@ def test_io(self): wf.n2.outputs.y, wf.outputs.intermediate, msg="IO should be by reference" ) self.assertNotIn(wf.n3.outputs.y, wf.outputs, msg="IO should be hidable") - + def test_io_maps(self): # input and output, renaming, accessing connected, and deactivating disconnected wf = Workflow("wf") @@ -113,36 +110,30 @@ def test_io_maps(self): wf.set_run_signals_to_dag_execution() out = 
wf.run() self.assertEqual( - 3, - out.y, - msg="New names should be propagated to the returned value" + 3, out.y, msg="New names should be propagated to the returned value" ) self.assertNotIn( "m__y", list(out.keys()), - msg="IO filtering should be evident in returned value" + msg="IO filtering should be evident in returned value", ) self.assertEqual( 43, wf.m.outputs.y.value, - msg="The child channel should still exist and have run" + msg="The child channel should still exist and have run", ) self.assertEqual( - 1, - wf.inputs.intermediate_x.value, - msg="IO should be up-to-date post-run" + 1, wf.inputs.intermediate_x.value, msg="IO should be up-to-date post-run" ) self.assertEqual( - 2, - wf.outputs.intermediate_y.value, - msg="IO should be up-to-date post-run" + 2, wf.outputs.intermediate_y.value, msg="IO should be up-to-date post-run" ) def test_io_map_bijectivity(self): wf = Workflow("wf") with self.assertRaises( ValueDuplicationError, - msg="Should not be allowed to map two children's channels to the same label" + msg="Should not be allowed to map two children's channels to the same label", ): wf.inputs_map = {"n1__x": "x", "n2__x": "x"} @@ -150,7 +141,7 @@ def test_io_map_bijectivity(self): with self.assertRaises( ValueDuplicationError, msg="Should not be allowed to update a second child's channel onto an " - "existing mapped channel" + "existing mapped channel", ): wf.inputs_map["n2__x"] = "x" @@ -161,35 +152,25 @@ def test_io_map_bijectivity(self): wf.inputs_map["n1__x"] = None wf.inputs_map["n2__x"] = None wf.inputs_map["n3__x"] = None - self.assertEqual( - 3, - len(wf.inputs_map), - msg="All entries should be stored" - ) - self.assertEqual( - 0, - len(wf.inputs), - msg="No IO should be left exposed" - ) + self.assertEqual(3, len(wf.inputs_map), msg="All entries should be stored") + self.assertEqual(0, len(wf.inputs), msg="No IO should be left exposed") def test_is_parentmost(self): wf = Workflow("wf") wf2 = Workflow("wf2") with self.assertRaises( - ParentMostError, - msg="Workflows are promised in the docs to be parent-most" + ParentMostError, msg="Workflows are promised in the docs to be parent-most" ): wf.parent = wf2 with self.assertRaises( ParentMostError, - msg="We want to catch parent-most failures early when assigning children" + msg="We want to catch parent-most failures early when assigning children", ): wf.sub_wf = wf2 def test_with_executor(self): - wf = Workflow("wf") wf.a = wf.create.function_node(plus_one) wf.b = wf.create.function_node(plus_one, x=wf.a) @@ -200,38 +181,34 @@ def test_with_executor(self): self.assertIs( NOT_DATA, wf.outputs.b__y.value, - msg="Sanity check that test is in right starting condition" + msg="Sanity check that the test is in the right starting condition", ) result = wf(a__x=0) self.assertIsInstance( - result, - Future, - msg="Should be running as a parallel process" + result, Future, msg="Should be running as a parallel process" ) returned_nodes = result.result(timeout=120) # Wait for the process to finish self.assertIsNot( original_a, returned_nodes.a, - msg="Executing in a parallel process should be returning new instances" + msg="Executing in a parallel process should be returning new instances", ) self.assertIs( - wf, - wf.a.parent, - msg="Returned nodes should get the macro as their parent" + wf, wf.a.parent, msg="Returned nodes should get the workflow as their parent" ) self.assertIsNone( original_a.parent, msg=f"Original nodes should be orphaned, but {original_a.full_label} has " - f"parent {original_a.parent}" + f"parent 
{original_a.parent}", # Note: At time of writing, this is accomplished in Node.__getstate__, # which feels a bit dangerous... ) self.assertEqual( 0 + 1 + 1, wf.outputs.b__y.value, - msg="And of course we expect the calculation to actually run" + msg="And of course we expect the calculation to actually run", ) wf.executor_shutdown() @@ -246,33 +223,27 @@ def test_parallel_execution(self): wf.slow.run() wf.fast.run() - self.assertTrue( - wf.slow.running, - msg="The slow node should still be running" - ) + self.assertTrue(wf.slow.running, msg="The slow node should still be running") self.assertEqual( wf.fast.outputs.five.value, 5, - msg="The slow node should not prohibit the completion of the fast node" + msg="The slow node should not prohibit the completion of the fast node", ) self.assertEqual( wf.sum.outputs.sum.value, NOT_DATA, - msg="The slow node _should_ hold up the downstream node to which it inputs" + msg="The slow node _should_ hold up the downstream node to which it inputs", ) wf.slow.future.result(timeout=120) # Wait for it to finish - self.assertFalse( - wf.slow.running, - msg="The slow node should be done running" - ) + self.assertFalse(wf.slow.running, msg="The slow node should be done running") wf.sum.run() self.assertEqual( wf.sum.outputs.sum.value, 5 + 5, msg="After the slow node completes, its output should be updated as a " - "callback, and downstream nodes should proceed" + "callback, and downstream nodes should proceed", ) wf.executor_shutdown() @@ -292,14 +263,14 @@ def sum_(a, b): self.assertEqual( wf.a.outputs.y.value + wf.b.outputs.y.value, wf.sum.outputs.sum.value, - msg="Sanity check" + msg="Sanity check", ) wf(a__x=42, b__x=42) self.assertEqual( plus_one(42) + plus_one(42), wf.sum.outputs.sum.value, msg="Workflow should accept input channel kwargs and update inputs " - "accordingly" + "accordingly", # Since the nodes run automatically, there is no need for wf.run() here ) @@ -319,7 +290,7 @@ def test_return_value(self): return_on_call, DotDict({"b__y": 1 + 2}), msg="Run output should be returned on call. 
Expecting a DotDict of " - "output values" + "output values", ) wf.inputs.a__x = 2 @@ -328,7 +299,7 @@ def test_return_value(self): return_on_explicit_run["b__y"], 2 + 2, msg="On explicit run, the most recent input data should be used and " - "the result should be returned" + "the result should be returned", ) def test_execution_automation(self): @@ -346,13 +317,12 @@ def make_workflow(): return wf def matches_expectations(results): - expected = {'n2l__out': -9, 'n2m__out': 3, 'n2r__out': 12} + expected = {"n2l__out": -9, "n2m__out": 3, "n2r__out": 12} return all(expected[k] == v for k, v in results.items()) auto = make_workflow() self.assertTrue( - matches_expectations(auto()), - msg="DAGs should run automatically" + matches_expectations(auto()), msg="DAGs should run automatically" ) user = make_workflow() @@ -363,29 +333,29 @@ def matches_expectations(results): user.starting_nodes = [user.n1l] self.assertTrue( matches_expectations(user()), - msg="Users shoudl be allowed to ask to run things manually" + msg="Users should be allowed to ask to run things manually", ) self.assertIn( user.n1r.signals.output.ran, user.n2r.signals.input.run.connections, - msg="Expected execution signals as manually defined" + msg="Expected execution signals as manually defined", ) user.automate_execution = True self.assertTrue( matches_expectations(user()), - msg="Users should be able to switch back to automatic execution" + msg="Users should be able to switch back to automatic execution", ) self.assertNotIn( user.n1r.signals.output.ran, user.n2r.signals.input.run.connections, - msg="Expected old execution signals to be overwritten" + msg="Expected old execution signals to be overwritten", ) self.assertIn( user.n1r.signals.output.ran, user.n2r.signals.input.accumulate_and_run.connections, msg="The automated flow uses a non-linear accumulating approach, so the " - "accumulating run signal is the one that should hold a connection" + "accumulating run signal is the one that should hold a connection", ) with self.subTest("Make sure automated cyclic graphs throw an error"): @@ -415,13 +385,12 @@ def add_three_macro(self, one__x): self.assertEqual( (0 + 1) + (1 + 1), wf.m.two.pull(run_parent_trees_too=True), - msg="Sanity check, pulling here should work perfectly fine" + msg="Sanity check, pulling here should work perfectly fine", ) wf.m.one.executor = wf.create.ProcessPoolExecutor() with self.assertRaises( - ValueError, - msg="Should not be able to pull with executor in local scope" + ValueError, msg="Should not be able to pull with executor in local scope" ): wf.m.two.pull() wf.m.one.executor_shutdown() # Shouldn't get this far, but if so, shutdown @@ -429,8 +398,7 @@ def add_three_macro(self, one__x): wf.n1.executor = wf.create.ProcessPoolExecutor() with self.assertRaises( - ValueError, - msg="Should not be able to pull with executor in parent scope" + ValueError, msg="Should not be able to pull with executor in parent scope" ): wf.m.two.pull(run_parent_trees_too=True) @@ -454,12 +422,12 @@ def test_storage_values(self): self.assertEqual( wf_out.out__add, reloaded.outputs.out__add.value, - msg="Workflow-level data should get reloaded" + msg="Workflow-level data should get reloaded", ) self.assertEqual( three_result, reloaded.inp.three.value, - msg="Child data arbitrarily deep should get reloaded" + msg="Child data arbitrarily deep should get reloaded", ) finally: # Clean up after ourselves @@ -483,10 +451,13 @@ def test_storage_scopes(self): for backend in available_backends(): try: wf.import_type_mismatch = 
demo_nodes.Dynamic() - with self.subTest(backend), self.assertRaises( - TypeNotFoundError, - msg="Imported object is function but node type is node " - "-- should fail early on save" + with ( + self.subTest(backend), + self.assertRaises( + TypeNotFoundError, + msg="Imported object is function but node type is node " + "-- should fail early on save", + ), ): wf.save(backend=backend, cloudpickle_fallback=False) finally: @@ -494,6 +465,7 @@ def test_storage_scopes(self): wf.delete_storage(backend) with self.subTest("Unimportable node"): + @Workflow.wrap.as_function_node("y") def UnimportableScope(x): return x @@ -507,11 +479,9 @@ def test_pickle(self): wf_out = wf() reloaded = pickle.loads(pickle.dumps(wf)) self.assertDictEqual( - wf_out, - reloaded.outputs.to_value_dict(), - msg="Pickling should work" + wf_out, reloaded.outputs.to_value_dict(), msg="Pickling should work" ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()