diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 80084ba5..4a955753 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -25,6 +25,8 @@ from pyiron_workflow.util import SeabornColors if TYPE_CHECKING: + from pathlib import Path + import graphviz from pyiron_workflow.channels import Channel @@ -587,10 +589,26 @@ def color(self) -> str: return SeabornColors.white def draw( - self, depth: int = 1, rankdir: Literal["LR", "TB"] = "LR" + self, + depth: int = 1, + rankdir: Literal["LR", "TB"] = "LR", + save: bool = False, + view: bool = False, + directory: Optional[Path | str] = None, + filename: Optional[Path | str] = None, + format: Optional[str] = None, + cleanup: bool = True, ) -> graphviz.graphs.Digraph: """ - Draw the node structure. + Draw the node structure and return it as a graphviz object. + + A selection of the `graphviz.Graph.render` method options are exposed, and if + `view` or `filename` is provided, this will be called before returning the + graph. + The graph file and rendered image will be stored in the node's working + directory. + This is purely for convenience -- since we directly return a graphviz object + you can instead use this to leverage the full power of graphviz. Args: depth (int): How deeply to decompose the representation of composite nodes @@ -600,17 +618,37 @@ def draw( max depth of the node will have no adverse side effects. rankdir ("LR" | "TB"): Use left-right or top-bottom graphviz `rankdir` to orient the flow of the graph. + save (bool): Render the graph image. (Default is False. When True, all + other defaults will yield a PDF in the node's working directory.) + view (bool): `graphviz.Graph.render` argument, open the rendered result + with the default application. (Default is False. When True, default + values for the directory and filename are supplied by the node working + directory and label.) + directory (Path|str|None): `graphviz.Graph.render` argument, (sub)directory + for source saving and rendering. (Default is None, which uses the + node's working directory.) + filename (Path|str): `graphviz.Graph.render` argument, filename for saving + the source. (Default is None, which uses the node label + `"_graph"`. + format (str|None): `graphviz.Graph.render` argument, the output format used + for rendering ('pdf', 'png', etc.). + cleanup (bool): `graphviz.Graph.render` argument, delete the source file + after successful rendering. (Default is True -- unlike graphviz.) Returns: (graphviz.graphs.Digraph): The resulting graph object. - - Note: - The graphviz docs will elucidate all the possibilities of what to do with - the returned object, but the thing you are most likely to need is the - `render` method, which allows you to save the resulting graph as an image. - E.g. `self.draw().render(filename="my_node", format="png")`. """ - return GraphvizNode(self, depth=depth, rankdir=rankdir).graph + graph = GraphvizNode(self, depth=depth, rankdir=rankdir).graph + if save or view or filename is not None: + directory = self.working_directory.path if directory is None else directory + filename = self.label + "_graph" if filename is None else filename + graph.render( + view=view, + directory=directory, + filename=filename, + format=format, + cleanup=cleanup, + ) + return graph def activate_strict_hints(self): """Enable type hint checks for all data IO""" diff --git a/tests/unit/test_node.py b/tests/unit/test_node.py index 5c18e8d6..daedcab6 100644 --- a/tests/unit/test_node.py +++ b/tests/unit/test_node.py @@ -305,3 +305,33 @@ def test_working_directory(self): msg="Just want to make sure we cleaned up after ourselves" ) + def test_draw(self): + try: + self.n1.draw() + self.assertFalse( + any(self.n1.working_directory.path.iterdir()) + ) + + fmt = "pdf" # This is just so we concretely know the filename suffix + self.n1.draw(save=True, format=fmt) + expected_name = self.n1.label + "_graph." + fmt + # That name is just an implementation detail, update it as needed + self.assertTrue( + self.n1.working_directory.path.joinpath(expected_name).is_file(), + msg="If `save` is called, expect the rendered image to exist in the working" + "directory" + ) + + user_specified_name = "foo" + self.n1.draw(filename=user_specified_name, format=fmt) + expected_name = user_specified_name + "." + fmt + self.assertTrue( + self.n1.working_directory.path.joinpath(expected_name).is_file(), + msg="If the user specifies a filename, we should assume they want the " + "thing saved" + ) + finally: + # No matter what happens in the tests, clean up after yourself + self.n1.working_directory.delete() + +