diff --git a/diagrams/Diagram.py b/diagrams/Diagram.py index 446f1807b..135d4ea87 100644 --- a/diagrams/Diagram.py +++ b/diagrams/Diagram.py @@ -3,9 +3,10 @@ from .Context import Context from .utils import setdiagram + class Diagram(Context): __curvestyles = ("ortho", "curved") - __outformats = ("png", "jpg", "svg", "pdf") + __outformats = ("png", "jpg", "svg", "pdf", "dot") # fmt: off _default_graph_attrs = { @@ -47,7 +48,7 @@ def __init__( filename: str = "", direction: str = "LR", curvestyle: str = "ortho", - outformat: str = "png", + outformats: list = ["png"], show: bool = True, graph_attr: dict = {}, node_attr: dict = {}, @@ -61,7 +62,7 @@ def __init__( If not given, it will be generated from the name. :param direction: Data flow direction. Default is 'left to right'. :param curvestyle: Curve bending style. One of "ortho" or "curved". - :param outformat: Output file format. Default is 'png'. + :param outformats: List of output file formats. Default is ['png']. :param show: Open generated image after save if true, just only save otherwise. :param graph_attr: Provide graph_attr dot config attributes. :param node_attr: Provide node_attr dot config attributes. @@ -93,9 +94,10 @@ def __init__( raise ValueError(f'"{curvestyle}" is not a valid curvestyle') self.dot.graph_attr["splines"] = curvestyle - if not self._validate_outformat(outformat): - raise ValueError(f'"{outformat}" is not a valid output format') - self.outformat = outformat + for outformat in outformats: + if not self._validate_outformat(outformat): + raise ValueError(f'"{outformat}" is not a valid output format') + self.outformats = outformats # Merge passed in attributes self.dot.graph_attr.update(graph_attr) @@ -140,4 +142,5 @@ def connect(self, node: "Node", node2: "Node", edge: "Edge") -> None: self.dot.edge(node.nodeid, node2.nodeid, **edge.attrs) def render(self) -> None: - self.dot.render(format=self.outformat, view=self.show, quiet=True) + for outformat in self.outformats: + self.dot.render(format=outformat, view=self.show, quiet=True) diff --git a/tests/test_diagram.py b/tests/test_diagram.py index 00bdacc6e..9455e9c29 100644 --- a/tests/test_diagram.py +++ b/tests/test_diagram.py @@ -47,12 +47,12 @@ def test_validate_curvestyle(self): def test_validate_outformat(self): # Normal output formats. for fmt in ("png", "jpg", "svg", "pdf", "PNG", "dot"): - Diagram(outformat=fmt) + Diagram(outformats=[fmt]) # Invalid output formats. for fmt in ("pnp", "jpe", "unknown"): with self.assertRaises(ValueError): - Diagram(outformat=fmt) + Diagram(outformats=[fmt]) def test_with_global_context(self): self.assertIsNone(getdiagram()) @@ -115,12 +115,15 @@ def test_autolabel(self): def test_outformat_list(self): - """Check that outformat render all the files from the list.""" + """Check that outformats render all the files from the list.""" self.name = 'diagrams_image' - with Diagram(show=False, outformat=["dot", "png"]): + with Diagram(show=False, outformats=["dot", "png", "jpg", "svg", "pdf"]): Node("node1") # both files must exist self.assertTrue(os.path.exists(f"{self.name}.png")) + self.assertTrue(os.path.exists(f"{self.name}.jpg")) + self.assertTrue(os.path.exists(f"{self.name}.svg")) + self.assertTrue(os.path.exists(f"{self.name}.pdf")) self.assertTrue(os.path.exists(f"{self.name}.dot")) # clean the dot file as it only generated here